sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
The hunks below are from sglang/srt/mem_cache/memory_pool.py (item 110 above).

@@ -26,24 +26,17 @@ KVCache actually holds the physical kv cache.

 import abc
 import logging
-import threading
-from enum import IntEnum
-from functools import wraps
+from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union

 import numpy as np
-import psutil
 import torch
 import triton
 import triton.language as tl

+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import (
-    debug_timing,
-    get_compiler_backend,
-    is_cuda,
-    next_power_of_2,
-)
+from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2

 logger = logging.getLogger(__name__)

@@ -61,6 +54,7 @@ class ReqToTokenPool:
         device: str,
         enable_memory_saver: bool,
     ):
+
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -68,7 +62,7 @@ class ReqToTokenPool:
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            self.req_to_token = torch.zeros(
                (size, max_context_len), dtype=torch.int32, device=device
            )
@@ -128,6 +122,9 @@ class KVCache(abc.ABC):
             enable=enable_memory_saver
         )

+        # used for chunked cpu-offloading
+        self.cpu_offloading_chunk_size = 8192
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
@@ -150,89 +147,23 @@
     ) -> None:
         raise NotImplementedError()

-    @abc.abstractmethod
     def get_flat_data(self, indices):
         raise NotImplementedError()

-    @abc.abstractmethod
     def transfer(self, indices, flat_data):
         raise NotImplementedError()

-    @abc.abstractmethod
     def transfer_per_layer(self, indices, flat_data, layer_id):
         raise NotImplementedError()

     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter

+    def get_cpu_copy(self, indices):
+        raise NotImplementedError()

-class TokenToKVPoolAllocator:
-    """An allocator managing the indices to kv cache data."""
-
-    def __init__(
-        self,
-        size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-    ):
-        self.size = size
-        self.dtype = dtype
-        self.device = device
-        self.page_size = 1
-
-        self.free_slots = None
-        self.is_not_in_free_group = True
-        self.free_group = []
-        self.clear()
-
-        self._kvcache = kvcache
-
-    def available_size(self):
-        return len(self.free_slots)
-
-    def get_kvcache(self):
-        return self._kvcache
-
-    def alloc(self, need_size: int):
-        if need_size > len(self.free_slots):
-            return None
-
-        select_index = self.free_slots[:need_size]
-        self.free_slots = self.free_slots[need_size:]
-        return select_index
-
-    def free(self, free_index: torch.Tensor):
-        if free_index.numel() == 0:
-            return
-
-        if self.is_not_in_free_group:
-            self.free_slots = torch.cat((self.free_slots, free_index))
-        else:
-            self.free_group.append(free_index)
-
-    def free_group_begin(self):
-        self.is_not_in_free_group = False
-        self.free_group = []
-
-    def free_group_end(self):
-        self.is_not_in_free_group = True
-        if self.free_group:
-            self.free(torch.cat(self.free_group))
-
-    def backup_state(self):
-        return self.free_slots
-
-    def restore_state(self, free_slots):
-        self.free_slots = free_slots
-
-    def clear(self):
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = torch.arange(
-            1, self.size + 1, dtype=torch.int64, device=self.device
-        )
-        self.is_not_in_free_group = True
-        self.free_group = []
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        raise NotImplementedError()


 class MHATokenToKVPool(KVCache):
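The two hooks added above (`get_cpu_copy` / `load_cpu_copy`) define the chunked CPU-offloading interface that `MHATokenToKVPool` and `MLATokenToKVPool` implement later in this diff. A minimal caller sketch (the helper name and arguments are illustrative, not part of the package):

```python
import torch


def offload_and_restore(kv_pool, indices: torch.Tensor):
    """Round-trip the given KV slots through host memory.

    `kv_pool` is any KVCache subclass that implements the hooks added in 0.4.8;
    `indices` is an int64 tensor of token slot indices on the GPU.
    """
    # Chunked device -> host copy; returns one nested list of CPU tensors per layer.
    host_copy = kv_pool.get_cpu_copy(indices)
    # ... the host copy could be kept around or serialized here ...
    # Chunked host -> device copy back into the same slots.
    kv_pool.load_cpu_copy(host_copy, indices)
    return host_copy
```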
@@ -263,11 +194,25 @@ class MHATokenToKVPool(KVCache):

         self.head_num = head_num
         self.head_dim = head_dim
+
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         self._create_buffers()

         self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = self.device_module.Stream() if is_cuda else None
+        self.alt_stream = self.device_module.Stream() if _is_cuda else None

         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
@@ -275,25 +220,43 @@ class MHATokenToKVPool(KVCache):
         )

     def _create_buffers(self):
-        with self.memory_saver_adapter.region():
-            # [size, head_num, head_dim] for each layer
-            # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.k_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
-            self.v_buffer = [
-                torch.zeros(
-                    (self.size + self.page_size, self.head_num, self.head_dim),
-                    dtype=self.store_dtype,
-                    device=self.device,
-                )
-                for _ in range(self.layer_num)
-            ]
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.enable_custom_mem_pool
+                else nullcontext()
+            ):
+                # [size, head_num, head_dim] for each layer
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.k_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+                self.v_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_buffer + self.v_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
+        self.data_strides = torch.tensor(
+            [
+                np.prod(x.shape[1:]) * x.dtype.itemsize
+                for x in self.k_buffer + self.v_buffer
+            ],
+            device=self.device,
+        )

     def _clear_buffers(self):
         del self.k_buffer
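Note how `_create_buffers` above now conditionally enters `torch.cuda.use_mem_pool` and falls back to `nullcontext()` otherwise. The same pattern in isolation, as a sketch (the helper name is illustrative; it assumes a PyTorch build with `torch.cuda.MemPool` support, which the diff itself requires):

```python
from contextlib import nullcontext

import torch


def zeros_in_optional_pool(shape, dtype, device, custom_mem_pool=None):
    # Allocate inside the custom CUDA memory pool only when one is configured;
    # otherwise nullcontext() turns the `with` block into a no-op.
    ctx = (
        torch.cuda.use_mem_pool(custom_mem_pool)
        if custom_mem_pool is not None
        else nullcontext()
    )
    with ctx:
        return torch.zeros(shape, dtype=dtype, device=device)
```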
@@ -315,20 +278,66 @@ class MHATokenToKVPool(KVCache):
         # layer_num x [seq_len, head_num, head_dim]
         # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
-            self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
+            self.get_key_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
         kv_data_lens = [
-            self.get_key_buffer(i).nbytes for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
+            self.get_key_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
         kv_item_lens = [
             self.get_key_buffer(i)[0].nbytes * self.page_size
-            for i in range(self.layer_num)
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
             self.get_value_buffer(i)[0].nbytes * self.page_size
-            for i in range(self.layer_num)
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                k_cpu = self.k_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                v_cpu = self.v_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append([k_cpu, v_cpu])
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                k_cpu, v_cpu = (
+                    kv_cache_cpu[layer_id][i // chunk_size][0],
+                    kv_cache_cpu[layer_id][i // chunk_size][1],
+                )
+                assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
+                k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
+                v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
+                self.k_buffer[layer_id][chunk_indices] = k_chunk
+                self.v_buffer[layer_id][chunk_indices] = v_chunk
+        torch.cuda.synchronize()
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
@@ -411,35 +420,15 @@ class MHATokenToKVPool(KVCache):
         self.k_buffer[layer_id - self.start_layer][loc] = cache_k
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v

-
-@torch.compile
-def fused_downcast(
-    cache_k: torch.Tensor,
-    cache_v: torch.Tensor,
-    k_scale: torch.Tensor,
-    v_scale: torch.Tensor,
-    dtype: torch.dtype,
-    store_dtype: torch.dtype,
-    max_fp8: float,
-    min_fp8: float,
-):
-    cache_k = cache_k / k_scale
-    cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
-    cache_v = cache_v / v_scale
-    cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
-    cache_k = cache_k.to(dtype)
-    cache_v = cache_v.to(dtype)
-    cache_k = cache_k.view(store_dtype)
-    cache_v = cache_v.view(store_dtype)
-    return cache_k, cache_v
-
-
-# This compiled version is slower in the unit test
-# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
-    dst_1[loc] = src_1.to(dtype).view(store_dtype)
-    dst_2[loc] = src_2.to(dtype).view(store_dtype)
+    def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
+        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
+            self.data_ptrs,
+            self.data_strides,
+            tgt_loc,
+            src_loc,
+            len(tgt_loc),
+            next_power_of_2(len(tgt_loc)),
+        )


 @triton.jit
@@ -536,16 +525,34 @@ class MLATokenToKVPool(KVCache):
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim

-        with self.memory_saver_adapter.region():
-            # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = [
-                torch.zeros(
-                    (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
-                    dtype=self.store_dtype,
-                    device=device,
-                )
-                for _ in range(layer_num)
-            ]
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.custom_mem_pool
+                else nullcontext()
+            ):
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.kv_buffer = [
+                    torch.zeros(
+                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                        dtype=self.store_dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]

         self.layer_transfer_counter = None

@@ -571,6 +578,9 @@ class MLATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
@@ -644,6 +654,33 @@ class MLATokenToKVPool(KVCache):
         flat_data = flat_data.to(device=self.device, non_blocking=False)
         self.kv_buffer[layer_id - self.start_layer][indices] = flat_data

+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append(kv_cpu)
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
+                assert kv_cpu.shape[0] == len(chunk_indices)
+                kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
+                self.kv_buffer[layer_id][chunk_indices] = kv_chunk
+        torch.cuda.synchronize()
+

 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
@@ -671,7 +708,7 @@ class DoubleSparseTokenToKVPool(KVCache):
             end_layer,
         )

-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
                 torch.zeros(
@@ -733,368 +770,39 @@ class DoubleSparseTokenToKVPool(KVCache):
         pass


-class MemoryStateInt(IntEnum):
-    IDLE = 0
-    RESERVED = 1
-    PROTECTED = 2
-    SYNCED = 3
-    BACKUP = 4
-
-
-def synchronized(debug_only=False):
-    def _decorator(func):
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            if (not debug_only) or self.debug:
-                return func(self, *args, **kwargs)
-                with self.lock:
-                    return func(self, *args, **kwargs)
-            else:
-                return True
-
-        return wrapper
-
-    return _decorator
-
-
-class HostKVCache(abc.ABC):
-
-    def __init__(
-        self,
-        device_pool: KVCache,
-        host_to_device_ratio: float,
-        host_size: int,
-        pin_memory: bool,
-        device: str,
-        page_size: int,
-    ):
-        self.device_pool = device_pool
-        self.dtype = device_pool.store_dtype
-        self.pin_memory = pin_memory
-        self.device = device
-        self.page_size = page_size
-        self.size_per_token = self.get_size_per_token()
-        if host_size > 0:
-            self.size = int(host_size * 1e9 // self.size_per_token)
-        else:
-            self.size = int(device_pool.size * host_to_device_ratio)
-        # Align the host memory pool size to the page size
-        self.size = self.size - (self.size % self.page_size)
-        self.start_layer = device_pool.start_layer
-        self.end_layer = device_pool.end_layer
-
-        assert (
-            self.size > device_pool.size
-        ), "The host memory should be larger than the device memory with the current protocol"
-
-        # Verify there is enough available host memory.
-        host_mem = psutil.virtual_memory()
-        requested_bytes = self.size * self.size_per_token
-        # preserve at least 10GB for other usage
-        ten_gb = 10 * (1024**3)
-        if requested_bytes > host_mem.available - ten_gb:
-            raise ValueError(
-                f"Not enough host memory available. Requesting "
-                f"{requested_bytes / 1e9:.2f} GB but only have "
-                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
-                f"size of the hierarchical cache."
-            )
-        else:
-            logger.info(
-                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
-            )
-
-        self.kv_buffer = self.init_kv_buffer()
-
-        # A lock for synchronized operations on memory allocation and state transitions.
-        self.lock = threading.RLock()
-        self.debug = logger.isEnabledFor(logging.DEBUG)
-        self.clear()
-
-    @abc.abstractmethod
-    def get_size_per_token(self):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def init_kv_buffer(self):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def transfer(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data(self, indices):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data_by_layer(self, indices, layer_id):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def assign_flat_data(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @synchronized()
-    def clear(self):
-        # Initialize memory states and tracking structures.
-        self.mem_state = torch.zeros(
-            (self.size,), dtype=torch.uint8, device=self.device
-        )
-        self.free_slots = torch.arange(self.size, dtype=torch.int64)
-
-    def available_size(self):
-        return len(self.free_slots)
-
-    @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
-        if need_size > self.available_size():
-            return None
-
-        select_index = self.free_slots[:need_size]
-        self.free_slots = self.free_slots[need_size:]
-
-        if self.debug:
-            self.mem_state[select_index] = MemoryStateInt.RESERVED
-
-        return select_index
-
-    @synchronized()
-    def free(self, indices: torch.Tensor) -> int:
-        self.free_slots = torch.cat([self.free_slots, indices])
-        if self.debug:
-            self.mem_state[indices] = MemoryStateInt.IDLE
-        return len(indices)
-
-    @synchronized(debug_only=True)
-    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
-
-    @synchronized(debug_only=True)
-    def is_reserved(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.RESERVED
-
-    @synchronized(debug_only=True)
-    def is_protected(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def is_synced(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def is_backup(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_backup(self, indices: torch.Tensor):
-        if not self.is_synced(indices):
-            raise ValueError(
-                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_synced(self, indices: torch.Tensor):
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def protect_write(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be RESERVED before write operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def protect_load(self, indices: torch.Tensor):
-        if not self.is_backup(indices):
-            raise ValueError(
-                f"The host memory slots should be in BACKUP state before load operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def complete_io(self, indices: torch.Tensor):
-        if not self.is_protected(indices):
-            raise ValueError(
-                f"The host memory slots should be PROTECTED during I/O operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-
-class MHATokenToKVPoolHost(HostKVCache):
-    device_pool: MHATokenToKVPool
-
-    def __init__(
-        self,
-        device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float,
-        host_size: int,
-        page_size: int,
-        pin_memory: bool = True,
-        device: str = "cpu",
-    ):
-        super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
-        )
-
-    def get_size_per_token(self):
-        self.head_num = self.device_pool.head_num
-        self.head_dim = self.device_pool.head_dim
-        self.layer_num = self.device_pool.layer_num
-
-        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
-
-    def init_kv_buffer(self):
-        return torch.empty(
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
-            dtype=self.dtype,
-            device=self.device,
-            pin_memory=self.pin_memory,
-        )
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, :, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, :, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, :, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
-            device_pool.v_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
+@triton.jit
+def copy_all_layer_kv_cache(
+    data_ptrs,
+    strides,
+    tgt_loc_ptr,
+    src_loc_ptr,
+    num_locs,
+    num_locs_upper: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 128

+    bid = tl.program_id(0)
+    stride = tl.load(strides + bid)

-class MLATokenToKVPoolHost(HostKVCache):
-    device_pool: MLATokenToKVPool
+    data_ptr = tl.load(data_ptrs + bid)
+    data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))

-    def __init__(
-        self,
-        device_pool: MLATokenToKVPool,
-        host_to_device_ratio: float,
-        host_size: int,
-        page_size: int,
-        pin_memory: bool = True,
-        device: str = "cpu",
-    ):
-        super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
-        )
+    num_locs_offset = tl.arange(0, num_locs_upper)
+    tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)

-    def get_size_per_token(self):
-        self.kv_lora_rank = self.device_pool.kv_lora_rank
-        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
-        self.layer_num = self.device_pool.layer_num
+    # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
+    # because this copy is an inplace operation.

-        return (
-            (self.kv_lora_rank + self.qk_rope_head_dim)
-            * 1
-            * self.dtype.itemsize
-            * self.layer_num
+    num_loop = tl.cdiv(stride, BLOCK_SIZE)
+    for i in range(num_loop):
+        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :]
+        value = tl.load(
+            data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
         )
-
-    def init_kv_buffer(self):
-        return torch.empty(
-            (
-                self.layer_num,
-                self.size,
-                1,
-                self.kv_lora_rank + self.qk_rope_head_dim,
-            ),
-            dtype=self.dtype,
-            device=self.device,
-            pin_memory=self.pin_memory,
-        )
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, indices] = flat_data.to(
-            device=self.device, non_blocking=False
+        tl.store(
+            data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
+            value,
+            mask=mask,
         )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
-                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.kv_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
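For reference, a small smoke test of the `copy_all_layer_kv_cache` kernel added above, mirroring the launch in `MHATokenToKVPool.move_kv_cache`. This is a sketch only: it assumes a CUDA device with triton available, and the buffer sizes, slot choices, and function name are illustrative.

```python
import torch
import triton

from sglang.srt.mem_cache.memory_pool import copy_all_layer_kv_cache


def smoke_test_copy_all_layer_kv_cache():
    # One flat byte buffer standing in for a single (layer, K or V) tensor:
    # `num_slots` rows of `bytes_per_slot` bytes each.
    num_slots, bytes_per_slot = 16, 256
    buf = torch.randint(
        0, 256, (num_slots, bytes_per_slot), dtype=torch.uint8, device="cuda"
    )

    data_ptrs = torch.tensor([buf.data_ptr()], dtype=torch.uint64, device="cuda")
    data_strides = torch.tensor([bytes_per_slot], device="cuda")
    src_loc = torch.tensor([1, 2, 3], dtype=torch.int64, device="cuda")
    tgt_loc = torch.tensor([8, 9, 10], dtype=torch.int64, device="cuda")
    expected = buf[src_loc].clone()

    # Same launch shape as move_kv_cache: one program per registered buffer.
    copy_all_layer_kv_cache[(len(data_ptrs),)](
        data_ptrs,
        data_strides,
        tgt_loc,
        src_loc,
        len(tgt_loc),
        triton.next_power_of_2(len(tgt_loc)),
    )
    assert torch.equal(buf[tgt_loc], expected)
```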