sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
@@ -54,7 +54,11 @@ from sglang.srt.disaggregation.utils import (
54
54
  TransferBackend,
55
55
  get_kv_class,
56
56
  )
57
- from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
57
+ from sglang.srt.hf_transformers_utils import (
58
+ get_processor,
59
+ get_tokenizer,
60
+ get_tokenizer_from_processor,
61
+ )
58
62
  from sglang.srt.managers.io_struct import (
59
63
  AbortReq,
60
64
  BatchEmbeddingOut,
@@ -86,6 +90,8 @@ from sglang.srt.managers.io_struct import (
86
90
  ResumeMemoryOccupationReqInput,
87
91
  ResumeMemoryOccupationReqOutput,
88
92
  SessionParams,
93
+ SlowDownReqInput,
94
+ SlowDownReqOutput,
89
95
  TokenizedEmbeddingReqInput,
90
96
  TokenizedGenerateReqInput,
91
97
  UpdateWeightFromDiskReqInput,
@@ -161,17 +167,7 @@ class TokenizerManager:
161
167
  # Read model args
162
168
  self.model_path = server_args.model_path
163
169
  self.served_model_name = server_args.served_model_name
164
- self.model_config = ModelConfig(
165
- server_args.model_path,
166
- trust_remote_code=server_args.trust_remote_code,
167
- revision=server_args.revision,
168
- context_length=server_args.context_length,
169
- model_override_args=server_args.json_model_override_args,
170
- is_embedding=server_args.is_embedding,
171
- enable_multimodal=server_args.enable_multimodal,
172
- dtype=server_args.dtype,
173
- quantization=server_args.quantization,
174
- )
170
+ self.model_config = ModelConfig.from_server_args(server_args)
175
171
 
176
172
  self.is_generation = self.model_config.is_generation
177
173
  self.is_image_gen = self.model_config.is_image_gen
@@ -199,7 +195,7 @@ class TokenizerManager:
199
195
  self.tokenizer = self.processor = None
200
196
  else:
201
197
  self.processor = _processor
202
- self.tokenizer = self.processor.tokenizer
198
+ self.tokenizer = get_tokenizer_from_processor(self.processor)
203
199
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
204
200
  else:
205
201
  self.mm_processor = get_dummy_processor()
@@ -265,6 +261,9 @@ class TokenizerManager:
265
261
  self.resume_memory_occupation_communicator = _Communicator(
266
262
  self.send_to_scheduler, server_args.dp_size
267
263
  )
264
+ self.slow_down_communicator = _Communicator(
265
+ self.send_to_scheduler, server_args.dp_size
266
+ )
268
267
  self.flush_cache_communicator = _Communicator(
269
268
  self.send_to_scheduler, server_args.dp_size
270
269
  )
@@ -318,6 +317,10 @@ class TokenizerManager:
318
317
  ResumeMemoryOccupationReqOutput,
319
318
  self.resume_memory_occupation_communicator.handle_recv,
320
319
  ),
320
+ (
321
+ SlowDownReqOutput,
322
+ self.slow_down_communicator.handle_recv,
323
+ ),
321
324
  (
322
325
  FlushCacheReqOutput,
323
326
  self.flush_cache_communicator.handle_recv,
@@ -876,6 +879,14 @@ class TokenizerManager:
876
879
  self.auto_create_handle_loop()
877
880
  await self.resume_memory_occupation_communicator(obj)
878
881
 
882
+ async def slow_down(
883
+ self,
884
+ obj: SlowDownReqInput,
885
+ request: Optional[fastapi.Request] = None,
886
+ ):
887
+ self.auto_create_handle_loop()
888
+ await self.slow_down_communicator(obj)
889
+
879
890
  async def open_session(
880
891
  self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
881
892
  ):
@@ -15,12 +15,17 @@
15
15
 
16
16
  import logging
17
17
  import threading
18
- from typing import Optional, Tuple
18
+ from typing import Optional, Tuple, Union
19
19
 
20
20
  import torch
21
21
 
22
22
  from sglang.srt.configs.model_config import ModelConfig
23
- from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
23
+ from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
24
+ from sglang.srt.hf_transformers_utils import (
25
+ get_processor,
26
+ get_tokenizer,
27
+ get_tokenizer_from_processor,
28
+ )
24
29
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
25
30
  from sglang.srt.managers.io_struct import (
26
31
  GetWeightsByNameReqInput,
@@ -31,7 +36,7 @@ from sglang.srt.managers.io_struct import (
31
36
  )
32
37
  from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
33
38
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
34
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
39
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
35
40
  from sglang.srt.model_executor.model_runner import ModelRunner
36
41
  from sglang.srt.server_args import ServerArgs
37
42
  from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
@@ -47,6 +52,7 @@ class TpModelWorker:
47
52
  server_args: ServerArgs,
48
53
  gpu_id: int,
49
54
  tp_rank: int,
55
+ pp_rank: int,
50
56
  dp_rank: Optional[int],
51
57
  nccl_port: int,
52
58
  is_draft_worker: bool = False,
@@ -54,30 +60,29 @@ class TpModelWorker:
54
60
  token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
55
61
  ):
56
62
  # Parse args
63
+ self.tp_size = server_args.tp_size
57
64
  self.tp_rank = tp_rank
65
+ self.pp_rank = pp_rank
58
66
 
59
67
  # Init model and tokenizer
60
- self.model_config = ModelConfig(
61
- (
68
+ self.model_config = ModelConfig.from_server_args(
69
+ server_args,
70
+ model_path=(
62
71
  server_args.model_path
63
72
  if not is_draft_worker
64
73
  else server_args.speculative_draft_model_path
65
74
  ),
66
- trust_remote_code=server_args.trust_remote_code,
67
- revision=server_args.revision,
68
- context_length=server_args.context_length,
69
- model_override_args=server_args.json_model_override_args,
70
- is_embedding=server_args.is_embedding,
71
- enable_multimodal=server_args.enable_multimodal,
72
- dtype=server_args.dtype,
73
- quantization=server_args.quantization,
75
+ is_draft_model=is_draft_worker,
74
76
  )
77
+
75
78
  self.model_runner = ModelRunner(
76
79
  model_config=self.model_config,
77
80
  mem_fraction_static=server_args.mem_fraction_static,
78
81
  gpu_id=gpu_id,
79
82
  tp_rank=tp_rank,
80
83
  tp_size=server_args.tp_size,
84
+ pp_rank=pp_rank,
85
+ pp_size=server_args.pp_size,
81
86
  nccl_port=nccl_port,
82
87
  server_args=server_args,
83
88
  is_draft_worker=is_draft_worker,
@@ -94,7 +99,7 @@ class TpModelWorker:
94
99
  trust_remote_code=server_args.trust_remote_code,
95
100
  revision=server_args.revision,
96
101
  )
97
- self.tokenizer = self.processor.tokenizer
102
+ self.tokenizer = get_tokenizer_from_processor(self.processor)
98
103
  else:
99
104
  self.tokenizer = get_tokenizer(
100
105
  server_args.tokenizer_path,
@@ -104,6 +109,10 @@ class TpModelWorker:
104
109
  )
105
110
  self.device = self.model_runner.device
106
111
 
112
+ # Init nccl groups
113
+ self.pp_group = get_pp_group()
114
+ self.world_group = get_world_group()
115
+
107
116
  # Profile number of tokens
108
117
  self.max_total_num_tokens = self.model_runner.max_total_num_tokens
109
118
  self.max_prefill_tokens = server_args.max_prefill_tokens
@@ -129,8 +138,9 @@ class TpModelWorker:
129
138
  # Sync random seed across TP workers
130
139
  self.random_seed = broadcast_pyobj(
131
140
  [server_args.random_seed],
132
- self.tp_rank,
133
- self.model_runner.tp_group.cpu_group,
141
+ self.tp_size * self.pp_rank + tp_rank,
142
+ self.world_group.cpu_group,
143
+ src=self.world_group.ranks[0],
134
144
  )[0]
135
145
  set_random_seed(self.random_seed)
136
146
 
@@ -155,11 +165,14 @@ class TpModelWorker:
155
165
  def get_pad_input_ids_func(self):
156
166
  return getattr(self.model_runner.model, "pad_input_ids", None)
157
167
 
158
- def get_tp_cpu_group(self):
159
- return self.model_runner.tp_group.cpu_group
168
+ def get_tp_group(self):
169
+ return self.model_runner.tp_group
170
+
171
+ def get_attention_tp_group(self):
172
+ return self.model_runner.attention_tp_group
160
173
 
161
174
  def get_attention_tp_cpu_group(self):
162
- return self.model_runner.attention_tp_group.cpu_group
175
+ return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
163
176
 
164
177
  def get_memory_pool(self):
165
178
  return (
@@ -171,19 +184,38 @@ class TpModelWorker:
171
184
  self,
172
185
  model_worker_batch: ModelWorkerBatch,
173
186
  skip_sample: bool = False,
174
- ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
187
+ ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
175
188
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
176
- logits_output = self.model_runner.forward(forward_batch)
177
189
 
178
- if model_worker_batch.launch_done is not None:
179
- model_worker_batch.launch_done.set()
190
+ pp_proxy_tensors = None
191
+ if not self.pp_group.is_first_rank:
192
+ pp_proxy_tensors = PPProxyTensors(
193
+ self.pp_group.recv_tensor_dict(
194
+ all_gather_group=self.get_attention_tp_group()
195
+ )
196
+ )
180
197
 
181
- if skip_sample:
182
- next_token_ids = None
183
- else:
184
- next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
198
+ if self.pp_group.is_last_rank:
199
+ logits_output = self.model_runner.forward(
200
+ forward_batch, pp_proxy_tensors=pp_proxy_tensors
201
+ )
202
+ if model_worker_batch.launch_done is not None:
203
+ model_worker_batch.launch_done.set()
204
+
205
+ if skip_sample:
206
+ next_token_ids = None
207
+ else:
208
+ next_token_ids = self.model_runner.sample(
209
+ logits_output, model_worker_batch
210
+ )
185
211
 
186
- return logits_output, next_token_ids
212
+ return logits_output, next_token_ids
213
+ else:
214
+ pp_proxy_tensors = self.model_runner.forward(
215
+ forward_batch,
216
+ pp_proxy_tensors=pp_proxy_tensors,
217
+ )
218
+ return pp_proxy_tensors.tensors, None
187
219
 
188
220
  def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
189
221
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
@@ -56,11 +56,14 @@ class TpModelWorkerClient:
56
56
  server_args: ServerArgs,
57
57
  gpu_id: int,
58
58
  tp_rank: int,
59
+ pp_rank: int,
59
60
  dp_rank: Optional[int],
60
61
  nccl_port: int,
61
62
  ):
62
63
  # Load the model
63
- self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
64
+ self.worker = TpModelWorker(
65
+ server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
66
+ )
64
67
  self.max_running_requests = self.worker.max_running_requests
65
68
  self.device = self.worker.device
66
69
  self.gpu_id = gpu_id
@@ -91,8 +94,11 @@ class TpModelWorkerClient:
91
94
  def get_pad_input_ids_func(self):
92
95
  return self.worker.get_pad_input_ids_func()
93
96
 
94
- def get_tp_cpu_group(self):
95
- return self.worker.get_tp_cpu_group()
97
+ def get_tp_group(self):
98
+ return self.worker.get_tp_group()
99
+
100
+ def get_attention_tp_group(self):
101
+ return self.worker.get_attention_tp_group()
96
102
 
97
103
  def get_attention_tp_cpu_group(self):
98
104
  return self.worker.get_attention_tp_cpu_group()
@@ -24,9 +24,11 @@ class ChunkCache(BasePrefixCache):
24
24
  self,
25
25
  req_to_token_pool: ReqToTokenPool,
26
26
  token_to_kv_pool_allocator: TokenToKVPoolAllocator,
27
+ page_size: int,
27
28
  ):
28
29
  self.req_to_token_pool = req_to_token_pool
29
30
  self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
31
+ self.page_size = page_size
30
32
 
31
33
  def reset(self):
32
34
  pass
@@ -214,6 +214,8 @@ class MHATokenToKVPool(KVCache):
214
214
  layer_num: int,
215
215
  device: str,
216
216
  enable_memory_saver: bool,
217
+ start_layer: Optional[int] = None,
218
+ end_layer: Optional[int] = None,
217
219
  ):
218
220
  self.size = size
219
221
  self.page_size = page_size
@@ -232,6 +234,8 @@ class MHATokenToKVPool(KVCache):
232
234
  self.head_dim = head_dim
233
235
  self.layer_num = layer_num
234
236
  self._create_buffers()
237
+ self.start_layer = start_layer or 0
238
+ self.end_layer = end_layer or layer_num - 1
235
239
 
236
240
  self.layer_transfer_counter = None
237
241
  self.capture_mode = False
@@ -281,6 +285,8 @@ class MHATokenToKVPool(KVCache):
281
285
 
282
286
  # for disagg
283
287
  def get_contiguous_buf_infos(self):
288
+ # layer_num x [seq_len, head_num, head_dim]
289
+ # layer_num x [page_num, page_size, head_num, head_dim]
284
290
  kv_data_ptrs = [
285
291
  self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
286
292
  ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
@@ -320,24 +326,24 @@ class MHATokenToKVPool(KVCache):
320
326
  # transfer prepared data from host to device
321
327
  flat_data = flat_data.to(device=self.device, non_blocking=False)
322
328
  k_data, v_data = flat_data[0], flat_data[1]
323
- self.k_buffer[layer_id][indices] = k_data
324
- self.v_buffer[layer_id][indices] = v_data
329
+ self.k_buffer[layer_id - self.start_layer][indices] = k_data
330
+ self.v_buffer[layer_id - self.start_layer][indices] = v_data
325
331
 
326
332
  def get_key_buffer(self, layer_id: int):
327
333
  if self.layer_transfer_counter is not None:
328
- self.layer_transfer_counter.wait_until(layer_id)
334
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
329
335
 
330
336
  if self.store_dtype != self.dtype:
331
- return self.k_buffer[layer_id].view(self.dtype)
332
- return self.k_buffer[layer_id]
337
+ return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
338
+ return self.k_buffer[layer_id - self.start_layer]
333
339
 
334
340
  def get_value_buffer(self, layer_id: int):
335
341
  if self.layer_transfer_counter is not None:
336
- self.layer_transfer_counter.wait_until(layer_id)
342
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
337
343
 
338
344
  if self.store_dtype != self.dtype:
339
- return self.v_buffer[layer_id].view(self.dtype)
340
- return self.v_buffer[layer_id]
345
+ return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
346
+ return self.v_buffer[layer_id - self.start_layer]
341
347
 
342
348
  def get_kv_buffer(self, layer_id: int):
343
349
  return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
@@ -368,13 +374,13 @@ class MHATokenToKVPool(KVCache):
368
374
  # Overlap the copy of K and V cache for small batch size
369
375
  current_stream = self.device_module.current_stream()
370
376
  self.alt_stream.wait_stream(current_stream)
377
+ self.k_buffer[layer_id - self.start_layer][loc] = cache_k
371
378
  with self.device_module.stream(self.alt_stream):
372
- self.k_buffer[layer_id][loc] = cache_k
373
- self.v_buffer[layer_id][loc] = cache_v
379
+ self.v_buffer[layer_id - self.start_layer][loc] = cache_v
374
380
  current_stream.wait_stream(self.alt_stream)
375
381
  else:
376
- self.k_buffer[layer_id][loc] = cache_k
377
- self.v_buffer[layer_id][loc] = cache_v
382
+ self.k_buffer[layer_id - self.start_layer][loc] = cache_k
383
+ self.v_buffer[layer_id - self.start_layer][loc] = cache_v
378
384
 
379
385
 
380
386
  @torch.compile
@@ -484,6 +490,8 @@ class MLATokenToKVPool(KVCache):
484
490
  layer_num: int,
485
491
  device: str,
486
492
  enable_memory_saver: bool,
493
+ start_layer: Optional[int] = None,
494
+ end_layer: Optional[int] = None,
487
495
  ):
488
496
  self.size = size
489
497
  self.page_size = page_size
@@ -497,6 +505,8 @@ class MLATokenToKVPool(KVCache):
497
505
  self.kv_lora_rank = kv_lora_rank
498
506
  self.qk_rope_head_dim = qk_rope_head_dim
499
507
  self.layer_num = layer_num
508
+ self.start_layer = start_layer or 0
509
+ self.end_layer = end_layer or layer_num - 1
500
510
 
501
511
  memory_saver_adapter = TorchMemorySaverAdapter.create(
502
512
  enable=enable_memory_saver
@@ -540,19 +550,21 @@ class MLATokenToKVPool(KVCache):
540
550
 
541
551
  def get_key_buffer(self, layer_id: int):
542
552
  if self.layer_transfer_counter is not None:
543
- self.layer_transfer_counter.wait_until(layer_id)
553
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
544
554
 
545
555
  if self.store_dtype != self.dtype:
546
- return self.kv_buffer[layer_id].view(self.dtype)
547
- return self.kv_buffer[layer_id]
556
+ return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
557
+ return self.kv_buffer[layer_id - self.start_layer]
548
558
 
549
559
  def get_value_buffer(self, layer_id: int):
550
560
  if self.layer_transfer_counter is not None:
551
- self.layer_transfer_counter.wait_until(layer_id)
561
+ self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
552
562
 
553
563
  if self.store_dtype != self.dtype:
554
- return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
555
- return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
564
+ return self.kv_buffer[layer_id - self.start_layer][
565
+ ..., : self.kv_lora_rank
566
+ ].view(self.dtype)
567
+ return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]
556
568
 
557
569
  def get_kv_buffer(self, layer_id: int):
558
570
  return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
@@ -568,9 +580,11 @@ class MLATokenToKVPool(KVCache):
568
580
  if cache_k.dtype != self.dtype:
569
581
  cache_k = cache_k.to(self.dtype)
570
582
  if self.store_dtype != self.dtype:
571
- self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
583
+ self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
584
+ self.store_dtype
585
+ )
572
586
  else:
573
- self.kv_buffer[layer_id][loc] = cache_k
587
+ self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
574
588
 
575
589
  def set_mla_kv_buffer(
576
590
  self,
@@ -605,7 +619,7 @@ class MLATokenToKVPool(KVCache):
605
619
  def transfer_per_layer(self, indices, flat_data, layer_id):
606
620
  # transfer prepared data from host to device
607
621
  flat_data = flat_data.to(device=self.device, non_blocking=False)
608
- self.kv_buffer[layer_id][indices] = flat_data
622
+ self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
609
623
 
610
624
 
611
625
  class DoubleSparseTokenToKVPool(KVCache):
@@ -620,6 +634,8 @@ class DoubleSparseTokenToKVPool(KVCache):
620
634
  device: str,
621
635
  heavy_channel_num: int,
622
636
  enable_memory_saver: bool,
637
+ start_layer: Optional[int] = None,
638
+ end_layer: Optional[int] = None,
623
639
  ):
624
640
  self.size = size
625
641
  self.page_size = page_size
@@ -657,17 +673,23 @@ class DoubleSparseTokenToKVPool(KVCache):
657
673
  for _ in range(layer_num)
658
674
  ]
659
675
 
676
+ self.start_layer = start_layer or 0
677
+ self.end_layer = end_layer or layer_num - 1
678
+
660
679
  def get_key_buffer(self, layer_id: int):
661
- return self.k_buffer[layer_id]
680
+ return self.k_buffer[layer_id - self.start_layer]
662
681
 
663
682
  def get_value_buffer(self, layer_id: int):
664
- return self.v_buffer[layer_id]
683
+ return self.v_buffer[layer_id - self.start_layer]
665
684
 
666
685
  def get_label_buffer(self, layer_id: int):
667
- return self.label_buffer[layer_id]
686
+ return self.label_buffer[layer_id - self.start_layer]
668
687
 
669
688
  def get_kv_buffer(self, layer_id: int):
670
- return self.k_buffer[layer_id], self.v_buffer[layer_id]
689
+ return (
690
+ self.k_buffer[layer_id - self.start_layer],
691
+ self.v_buffer[layer_id - self.start_layer],
692
+ )
671
693
 
672
694
  def set_kv_buffer(
673
695
  self,
@@ -679,9 +701,9 @@ class DoubleSparseTokenToKVPool(KVCache):
679
701
  ):
680
702
  # NOTE(Andy): ignore the dtype check
681
703
  layer_id = layer.layer_id
682
- self.k_buffer[layer_id][loc] = cache_k
683
- self.v_buffer[layer_id][loc] = cache_v
684
- self.label_buffer[layer_id][loc] = cache_label
704
+ self.k_buffer[layer_id - self.start_layer][loc] = cache_k
705
+ self.v_buffer[layer_id - self.start_layer][loc] = cache_v
706
+ self.label_buffer[layer_id - self.start_layer][loc] = cache_label
685
707
 
686
708
  def get_flat_data(self, indices):
687
709
  pass
@@ -930,7 +952,7 @@ class MHATokenToKVPoolHost(HostKVCache):
930
952
  return self.kv_buffer[:, :, indices]
931
953
 
932
954
  def get_flat_data_by_layer(self, indices, layer_id):
933
- return self.kv_buffer[:, layer_id, indices]
955
+ return self.kv_buffer[:, layer_id - self.start_layer, indices]
934
956
 
935
957
  def assign_flat_data(self, indices, flat_data):
936
958
  self.kv_buffer[:, :, indices] = flat_data
@@ -955,12 +977,20 @@ class MHATokenToKVPoolHost(HostKVCache):
955
977
  for i in range(len(device_indices_cpu)):
956
978
  h_index = host_indices[i * self.page_size]
957
979
  d_index = device_indices_cpu[i]
958
- device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
959
- self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
980
+ device_pool.k_buffer[layer_id - self.start_layer][
981
+ d_index : d_index + self.page_size
982
+ ].copy_(
983
+ self.kv_buffer[
984
+ 0, layer_id - self.start_layer, h_index : h_index + self.page_size
985
+ ],
960
986
  non_blocking=True,
961
987
  )
962
- device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
963
- self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
988
+ device_pool.v_buffer[layer_id - self.start_layer][
989
+ d_index : d_index + self.page_size
990
+ ].copy_(
991
+ self.kv_buffer[
992
+ 1, layer_id - self.start_layer, h_index : h_index + self.page_size
993
+ ],
964
994
  non_blocking=True,
965
995
  )
966
996
 
@@ -1015,7 +1045,7 @@ class MLATokenToKVPoolHost(HostKVCache):
1015
1045
  return self.kv_buffer[:, indices]
1016
1046
 
1017
1047
  def get_flat_data_by_layer(self, indices, layer_id):
1018
- return self.kv_buffer[layer_id, indices]
1048
+ return self.kv_buffer[layer_id - self.start_layer, indices]
1019
1049
 
1020
1050
  def assign_flat_data(self, indices, flat_data):
1021
1051
  self.kv_buffer[:, indices] = flat_data
@@ -1036,7 +1066,11 @@ class MLATokenToKVPoolHost(HostKVCache):
1036
1066
  for i in range(len(device_indices_cpu)):
1037
1067
  h_index = host_indices[i * self.page_size]
1038
1068
  d_index = device_indices_cpu[i]
1039
- device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
1040
- self.kv_buffer[layer_id, h_index : h_index + self.page_size],
1069
+ device_pool.kv_buffer[layer_id - self.start_layer][
1070
+ d_index : d_index + self.page_size
1071
+ ].copy_(
1072
+ self.kv_buffer[
1073
+ layer_id - self.start_layer, h_index : h_index + self.page_size
1074
+ ],
1041
1075
  non_blocking=True,
1042
1076
  )