sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,11 @@
+ import json
  import logging
+ import time
+ from collections import defaultdict
  from http import HTTPStatus
- from typing import Optional
+ from typing import Dict, List, Optional, Tuple
+
+ import torch

  from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req

@@ -8,7 +8,10 @@ import torch

  from sglang.srt.managers.cache_controller import HiCacheController
  from sglang.srt.mem_cache.memory_pool import (
+     MHATokenToKVPool,
      MHATokenToKVPoolHost,
+     MLATokenToKVPool,
+     MLATokenToKVPoolHost,
      ReqToTokenPool,
      TokenToKVPoolAllocator,
  )
@@ -26,14 +29,24 @@ class HiRadixCache(RadixCache):
          token_to_kv_pool_allocator: TokenToKVPoolAllocator,
          tp_cache_group: torch.distributed.ProcessGroup,
          page_size: int,
+         hicache_ratio: float,
      ):
          if page_size != 1:
              raise ValueError(
                  "Page size larger than 1 is not yet supported in HiRadixCache."
              )
-         self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-             token_to_kv_pool_allocator.get_kvcache()
-         )
+         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
+         if isinstance(self.kv_cache, MHATokenToKVPool):
+             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
+                 self.kv_cache, hicache_ratio
+             )
+         elif isinstance(self.kv_cache, MLATokenToKVPool):
+             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
+                 self.kv_cache, hicache_ratio
+             )
+         else:
+             raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
+
          self.tp_group = tp_cache_group
          self.page_size = page_size

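For a sense of what the new hicache_ratio argument costs, here is a back-of-envelope sizing sketch. The numbers are illustrative, not package defaults; the formula mirrors the HostKVCache sizing shown later in this diff (host slots = device_pool.size * host_to_device_ratio, each costing size_per_token bytes).

# Illustrative sizing for hicache_ratio; all values below are made-up example numbers.
device_tokens = 100_000                    # KV slots kept on the GPU
hicache_ratio = 3.0                        # host pool holds 3x the device capacity
layers, kv_heads, head_dim, dtype_bytes = 32, 8, 128, 2
size_per_token = 2 * layers * kv_heads * head_dim * dtype_bytes   # K and V, per token

host_bytes = int(device_tokens * hicache_ratio) * size_per_token
print(f"host KV cache needs ~{host_bytes / 1e9:.1f} GB")          # ~39.3 GB with these numbers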
@@ -295,9 +308,9 @@ class HiRadixCache(RadixCache):

          value, last_node = self._match_prefix_helper(self.root_node, key)
          if value:
-             value = torch.concat(value)
+             value = torch.cat(value)
          else:
-             value = torch.tensor([], dtype=torch.int32)
+             value = torch.tensor([], dtype=torch.int64)

          last_node_global = last_node
          while last_node.evicted:
@@ -317,13 +330,11 @@ class HiRadixCache(RadixCache):
              prefix_len = _key_match(child.key, key)
              if prefix_len < len(child.key):
                  new_node = self._split_node(child.key, child, prefix_len)
-                 self.inc_hit_count(new_node)
                  if not new_node.evicted:
                      value.append(new_node.value)
                  node = new_node
                  break
              else:
-                 self.inc_hit_count(child)
                  if not child.evicted:
                      value.append(child.value)
                  node = child
@@ -19,7 +19,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
  Memory pool.

  SGLang has two levels of memory pool.
- ReqToTokenPool maps a a request to its token locations.
+ ReqToTokenPool maps a request to its token locations.
  TokenToKVPoolAllocator manages the indices to kv cache data.
  KVCache actually holds the physical kv cache.
  """
@@ -115,6 +115,21 @@ class KVCache(abc.ABC):
      ) -> None:
          raise NotImplementedError()

+     @abc.abstractmethod
+     def get_flat_data(self, indices):
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def transfer(self, indices, flat_data):
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def transfer_per_layer(self, indices, flat_data, layer_id):
+         raise NotImplementedError()
+
+     def register_layer_transfer_counter(self, layer_transfer_counter):
+         self.layer_transfer_counter = layer_transfer_counter
+

  class TokenToKVPoolAllocator:
      """An allocator managing the indices to kv cache data."""
@@ -157,7 +172,7 @@ class TokenToKVPoolAllocator:
              return

          if self.is_not_in_free_group:
-             self.free_slots = torch.concat((self.free_slots, free_index))
+             self.free_slots = torch.cat((self.free_slots, free_index))
          else:
              self.free_group.append(free_index)

@@ -168,14 +183,14 @@
      def free_group_end(self):
          self.is_not_in_free_group = True
          if self.free_group:
-             self.free(torch.concat(self.free_group))
+             self.free(torch.cat(self.free_group))

      def clear(self):
          # The padded slot 0 is used for writing dummy outputs from padded tokens.
          self.free_slots = torch.arange(
              1, self.size + 1, dtype=torch.int64, device=self.device
          )
-         self.is_in_free_group = False
+         self.is_not_in_free_group = True
          self.free_group = []

@@ -212,7 +227,8 @@ class MHATokenToKVPool(KVCache):

          self.layer_transfer_counter = None
          self.capture_mode = False
-         self.alt_stream = torch.cuda.Stream()
+         self.device_module = torch.get_device_module(self.device)
+         self.alt_stream = self.device_module.Stream()

          k_size, v_size = self.get_kv_size_bytes()
          logger.info(
@@ -255,6 +271,19 @@ class MHATokenToKVPool(KVCache):
              v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
          return k_size_bytes, v_size_bytes

+     # for disagg
+     def get_contiguous_buf_infos(self):
+         kv_data_ptrs = [
+             self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
+         ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
+         kv_data_lens = [
+             self.get_key_buffer(i).nbytes for i in range(self.layer_num)
+         ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
+         kv_item_lens = [
+             self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
+         ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
+         return kv_data_ptrs, kv_data_lens, kv_item_lens
+
      # Todo: different memory layout
      def get_flat_data(self, indices):
          # prepare a large chunk of contiguous data for efficient transfer
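get_contiguous_buf_infos exposes raw pointer/length metadata for the per-layer K and V buffers, presumably so the new prefill/decode disaggregation path can register them with a transfer backend. A hedged sketch of a consumer; register_memory_region is a hypothetical callback, not an sglang or third-party API:

def register_kv_buffers(kv_pool, register_memory_region):
    kv_data_ptrs, kv_data_lens, kv_item_lens = kv_pool.get_contiguous_buf_infos()
    for ptr, total_len, item_len in zip(kv_data_ptrs, kv_data_lens, kv_item_lens):
        # One entry per contiguous per-layer K or V buffer: base address,
        # total byte length, and byte size of a single token slot.
        register_memory_region(base_ptr=ptr, num_bytes=total_len, bytes_per_token=item_len)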
@@ -275,9 +304,6 @@
              self.k_buffer[i][indices] = k_data[i]
              self.v_buffer[i][indices] = v_data[i]

-     def register_layer_transfer_counter(self, layer_transfer_counter):
-         self.layer_transfer_counter = layer_transfer_counter
-
      def transfer_per_layer(self, indices, flat_data, layer_id):
          # transfer prepared data from host to device
          flat_data = flat_data.to(device=self.device, non_blocking=False)
@@ -327,11 +353,13 @@
              cache_v = cache_v.view(self.store_dtype)

          if self.capture_mode and cache_k.shape[0] < 4:
-             self.alt_stream.wait_stream(torch.cuda.current_stream())
-             with torch.cuda.stream(self.alt_stream):
+             # Overlap the copy of K and V cache for small batch size
+             current_stream = self.device_module.current_stream()
+             self.alt_stream.wait_stream(current_stream)
+             with self.device_module.stream(self.alt_stream):
                  self.k_buffer[layer_id][loc] = cache_k
              self.v_buffer[layer_id][loc] = cache_v
-             torch.cuda.current_stream().wait_stream(self.alt_stream)
+             current_stream.wait_stream(self.alt_stream)
          else:
              self.k_buffer[layer_id][loc] = cache_k
              self.v_buffer[layer_id][loc] = cache_v
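The hunk above routes the K copy through an alternate stream (obtained via the device-agnostic torch.get_device_module accessor) so that, for tiny batches, the K and V writes overlap. A minimal standalone sketch of the same two-stream pattern on CUDA; tensors and sizes are illustrative:

import torch

def overlapped_copy(dst_k, dst_v, src_k, src_v, loc, alt_stream):
    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)     # alt stream sees all prior work
    with torch.cuda.stream(alt_stream):
        dst_k[loc] = src_k                     # K copy on the alternate stream
    dst_v[loc] = src_v                         # V copy on the current stream
    current_stream.wait_stream(alt_stream)     # rejoin before anyone reads K

if torch.cuda.is_available():
    k_buf = torch.zeros(16, 8, device="cuda")
    v_buf = torch.zeros(16, 8, device="cuda")
    loc = torch.tensor([0, 1], device="cuda")
    overlapped_copy(k_buf, v_buf,
                    torch.ones(2, 8, device="cuda"), torch.ones(2, 8, device="cuda"),
                    loc, torch.cuda.Stream())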
@@ -388,6 +416,8 @@ class MLATokenToKVPool(KVCache):
          else:
              self.store_dtype = dtype
          self.kv_lora_rank = kv_lora_rank
+         self.qk_rope_head_dim = qk_rope_head_dim
+         self.layer_num = layer_num

          memory_saver_adapter = TorchMemorySaverAdapter.create(
              enable=enable_memory_saver
@@ -404,12 +434,20 @@
              for _ in range(layer_num)
          ]

+         self.layer_transfer_counter = None
+
      def get_key_buffer(self, layer_id: int):
+         if self.layer_transfer_counter is not None:
+             self.layer_transfer_counter.wait_until(layer_id)
+
          if self.store_dtype != self.dtype:
              return self.kv_buffer[layer_id].view(self.dtype)
          return self.kv_buffer[layer_id]

      def get_value_buffer(self, layer_id: int):
+         if self.layer_transfer_counter is not None:
+             self.layer_transfer_counter.wait_until(layer_id)
+
          if self.store_dtype != self.dtype:
              return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
          return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
@@ -432,6 +470,22 @@
          else:
              self.kv_buffer[layer_id][loc] = cache_k

+     def get_flat_data(self, indices):
+         # prepare a large chunk of contiguous data for efficient transfer
+         return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
+
+     @debug_timing
+     def transfer(self, indices, flat_data):
+         # transfer prepared data from host to device
+         flat_data = flat_data.to(device=self.device, non_blocking=False)
+         for i in range(self.layer_num):
+             self.kv_buffer[i][indices] = flat_data[i]
+
+     def transfer_per_layer(self, indices, flat_data, layer_id):
+         # transfer prepared data from host to device
+         flat_data = flat_data.to(device=self.device, non_blocking=False)
+         self.kv_buffer[layer_id][indices] = flat_data
+

  class DoubleSparseTokenToKVPool(KVCache):
      def __init__(
@@ -508,6 +562,15 @@ class DoubleSparseTokenToKVPool(KVCache):
          self.v_buffer[layer_id][loc] = cache_v
          self.label_buffer[layer_id][loc] = cache_label

+     def get_flat_data(self, indices):
+         pass
+
+     def transfer(self, indices, flat_data):
+         pass
+
+     def transfer_per_layer(self, indices, flat_data, layer_id):
+         pass
+

  class MemoryStateInt(IntEnum):
      IDLE = 0
@@ -517,21 +580,28 @@ class MemoryStateInt(IntEnum):
      BACKUP = 4


- def synchronized(func):
-     @wraps(func)
-     def wrapper(self, *args, **kwargs):
-         with self.lock:
-             return func(self, *args, **kwargs)
+ def synchronized(debug_only=False):
+     def _decorator(func):
+         @wraps(func)
+         def wrapper(self, *args, **kwargs):
+             if (not debug_only) or self.debug:
+                 return func(self, *args, **kwargs)
+                 with self.lock:
+                     return func(self, *args, **kwargs)
+             else:
+                 return True

-     return wrapper
+         return wrapper

+     return _decorator

- class MHATokenToKVPoolHost:
+
+ class HostKVCache(abc.ABC):

      def __init__(
          self,
          device_pool: MHATokenToKVPool,
-         host_to_device_ratio: float = 3.0,
+         host_to_device_ratio: float,
          pin_memory: bool = False,  # no need to use pin memory with the double buffering
          device: str = "cpu",
      ):
@@ -547,12 +617,7 @@ class MHATokenToKVPoolHost:

          self.size = int(device_pool.size * host_to_device_ratio)
          self.dtype = device_pool.store_dtype
-         self.head_num = device_pool.head_num
-         self.head_dim = device_pool.head_dim
-         self.layer_num = device_pool.layer_num
-         self.size_per_token = (
-             self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
-         )
+         self.size_per_token = self.get_size_per_token()

          # Verify there is enough available host memory.
          host_mem = psutil.virtual_memory()
@@ -571,126 +636,218 @@ class MHATokenToKVPoolHost:
              f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
          )

-         self.kv_buffer = torch.zeros(
-             (2, self.layer_num, self.size, self.head_num, self.head_dim),
-             dtype=self.dtype,
-             device=self.device,
-             pin_memory=self.pin_memory,
-         )
-
-         # Initialize memory states and tracking structures.
-         self.mem_state = torch.zeros(
-             (self.size,), dtype=torch.uint8, device=self.device
-         )
-         self.free_slots = torch.arange(self.size, dtype=torch.int32)
-         self.can_use_mem_size = self.size
+         self.kv_buffer = self.init_kv_buffer()

          # A lock for synchronized operations on memory allocation and state transitions.
          self.lock = threading.RLock()
+         self.debug = logger.isEnabledFor(logging.DEBUG)
+         self.clear()

+     @abc.abstractmethod
+     def get_size_per_token(self):
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def init_kv_buffer(self):
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def transfer(self, indices, flat_data):
+         raise NotImplementedError()
+
+     @abc.abstractmethod
      def get_flat_data(self, indices):
-         return self.kv_buffer[:, :, indices]
+         raise NotImplementedError()

+     @abc.abstractmethod
      def get_flat_data_by_layer(self, indices, layer_id):
-         return self.kv_buffer[:, layer_id, indices]
+         raise NotImplementedError()

+     @abc.abstractmethod
      def assign_flat_data(self, indices, flat_data):
-         self.kv_buffer[:, :, indices] = flat_data
-
-     @debug_timing
-     def transfer(self, indices, flat_data):
-         # backup prepared data from device to host
-         self.kv_buffer[:, :, indices] = flat_data.to(
-             device=self.device, non_blocking=False
-         )
+         raise NotImplementedError()

-     @synchronized
+     @synchronized()
      def clear(self):
-         self.mem_state.fill_(0)
-         self.can_use_mem_size = self.size
-         self.free_slots = torch.arange(self.size, dtype=torch.int32)
+         # Initialize memory states and tracking structures.
+         self.mem_state = torch.zeros(
+             (self.size,), dtype=torch.uint8, device=self.device
+         )
+         self.free_slots = torch.arange(self.size, dtype=torch.int64)

-     @synchronized
-     def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-         assert len(indices) > 0, "The indices should not be empty"
-         states = self.mem_state[indices]
-         assert (
-             states == states[0]
-         ).all(), "The memory slots should have the same state {}".format(states)
-         return MemoryStateInt(states[0].item())
+     def available_size(self):
+         return len(self.free_slots)

-     @synchronized
+     @synchronized()
      def alloc(self, need_size: int) -> torch.Tensor:
-         if need_size > self.can_use_mem_size:
+         if need_size > self.available_size():
              return None

-         # todo: de-fragementation
          select_index = self.free_slots[:need_size]
          self.free_slots = self.free_slots[need_size:]

-         self.mem_state[select_index] = MemoryStateInt.RESERVED
-         self.can_use_mem_size -= need_size
+         if self.debug:
+             self.mem_state[select_index] = MemoryStateInt.RESERVED

          return select_index

-     @synchronized
+     @synchronized()
+     def free(self, indices: torch.Tensor) -> int:
+         self.free_slots = torch.cat([self.free_slots, indices])
+         if self.debug:
+             self.mem_state[indices] = MemoryStateInt.IDLE
+         return len(indices)
+
+     @synchronized(debug_only=True)
+     def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+         assert len(indices) > 0, "The indices should not be empty"
+         states = self.mem_state[indices]
+         assert (
+             states == states[0]
+         ).all(), "The memory slots should have the same state {}".format(states)
+         return MemoryStateInt(states[0].item())
+
+     @synchronized(debug_only=True)
      def is_reserved(self, indices: torch.Tensor) -> bool:
          return self.get_state(indices) == MemoryStateInt.RESERVED

-     @synchronized
+     @synchronized(debug_only=True)
      def is_protected(self, indices: torch.Tensor) -> bool:
          return self.get_state(indices) == MemoryStateInt.PROTECTED

-     @synchronized
+     @synchronized(debug_only=True)
      def is_synced(self, indices: torch.Tensor) -> bool:
          return self.get_state(indices) == MemoryStateInt.SYNCED

-     @synchronized
+     @synchronized(debug_only=True)
      def is_backup(self, indices: torch.Tensor) -> bool:
          return self.get_state(indices) == MemoryStateInt.BACKUP

-     @synchronized
+     @synchronized(debug_only=True)
      def update_backup(self, indices: torch.Tensor):
-         assert self.is_synced(indices), (
-             f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-             f"Current state: {self.get_state(indices)}"
-         )
+         if not self.is_synced(indices):
+             raise ValueError(
+                 f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+                 f"Current state: {self.get_state(indices)}"
+             )
          self.mem_state[indices] = MemoryStateInt.BACKUP

-     @synchronized
+     @synchronized(debug_only=True)
      def update_synced(self, indices: torch.Tensor):
          self.mem_state[indices] = MemoryStateInt.SYNCED

-     @synchronized
+     @synchronized(debug_only=True)
      def protect_write(self, indices: torch.Tensor):
-         assert self.is_reserved(indices), (
-             f"The host memory slots should be RESERVED before write operations. "
-             f"Current state: {self.get_state(indices)}"
-         )
+         if not self.is_reserved(indices):
+             raise ValueError(
+                 f"The host memory slots should be RESERVED before write operations. "
+                 f"Current state: {self.get_state(indices)}"
+             )
          self.mem_state[indices] = MemoryStateInt.PROTECTED

-     @synchronized
+     @synchronized(debug_only=True)
      def protect_load(self, indices: torch.Tensor):
-         assert self.is_backup(indices), (
-             f"The host memory slots should be in BACKUP state before load operations. "
-             f"Current state: {self.get_state(indices)}"
-         )
+         if not self.is_backup(indices):
+             raise ValueError(
+                 f"The host memory slots should be in BACKUP state before load operations. "
+                 f"Current state: {self.get_state(indices)}"
+             )
          self.mem_state[indices] = MemoryStateInt.PROTECTED

-     @synchronized
+     @synchronized(debug_only=True)
      def complete_io(self, indices: torch.Tensor):
-         assert self.is_protected(indices), (
-             f"The host memory slots should be PROTECTED during I/O operations. "
-             f"Current state: {self.get_state(indices)}"
-         )
+         if not self.is_protected(indices):
+             raise ValueError(
+                 f"The host memory slots should be PROTECTED during I/O operations. "
+                 f"Current state: {self.get_state(indices)}"
+             )
          self.mem_state[indices] = MemoryStateInt.SYNCED

-     def available_size(self):
-         return len(self.free_slots)

-     @synchronized
-     def free(self, indices: torch.Tensor) -> int:
-         self.mem_state[indices] = MemoryStateInt.IDLE
-         self.free_slots = torch.concat([self.free_slots, indices])
-         self.can_use_mem_size += len(indices)
-         return len(indices)
+ class MHATokenToKVPoolHost(HostKVCache):
+     def __init__(
+         self,
+         device_pool: MHATokenToKVPool,
+         host_to_device_ratio: float,
+         pin_memory: bool = False,  # no need to use pin memory with the double buffering
+         device: str = "cpu",
+     ):
+         super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+
+     def get_size_per_token(self):
+         self.head_num = self.device_pool.head_num
+         self.head_dim = self.device_pool.head_dim
+         self.layer_num = self.device_pool.layer_num
+
+         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+
+     def init_kv_buffer(self):
+         return torch.empty(
+             (2, self.layer_num, self.size, self.head_num, self.head_dim),
+             dtype=self.dtype,
+             device=self.device,
+             pin_memory=self.pin_memory,
+         )
+
+     @debug_timing
+     def transfer(self, indices, flat_data):
+         # backup prepared data from device to host
+         self.kv_buffer[:, :, indices] = flat_data.to(
+             device=self.device, non_blocking=False
+         )
+
+     def get_flat_data(self, indices):
+         return self.kv_buffer[:, :, indices]
+
+     def get_flat_data_by_layer(self, indices, layer_id):
+         return self.kv_buffer[:, layer_id, indices]
+
+     def assign_flat_data(self, indices, flat_data):
+         self.kv_buffer[:, :, indices] = flat_data
+
+
+ class MLATokenToKVPoolHost(HostKVCache):
+     def __init__(
+         self,
+         device_pool: MLATokenToKVPool,
+         host_to_device_ratio: float,
+         pin_memory: bool = False,  # no need to use pin memory with the double buffering
+         device: str = "cpu",
+     ):
+         super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+
+     def get_size_per_token(self):
+         self.kv_lora_rank = self.device_pool.kv_lora_rank
+         self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
+         self.layer_num = self.device_pool.layer_num
+
+         return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize
+
+     def init_kv_buffer(self):
+         return torch.empty(
+             (
+                 self.layer_num,
+                 self.size,
+                 1,
+                 self.kv_lora_rank + self.qk_rope_head_dim,
+             ),
+             dtype=self.dtype,
+             device=self.device,
+             pin_memory=self.pin_memory,
+         )
+
+     @debug_timing
+     def transfer(self, indices, flat_data):
+         # backup prepared data from device to host
+         self.kv_buffer[:, indices] = flat_data.to(
+             device=self.device, non_blocking=False
+         )
+
+     def get_flat_data(self, indices):
+         return self.kv_buffer[:, indices]
+
+     def get_flat_data_by_layer(self, indices, layer_id):
+         return self.kv_buffer[layer_id, indices]
+
+     def assign_flat_data(self, indices, flat_data):
+         self.kv_buffer[:, indices] = flat_data
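The refactor above turns synchronized into a decorator factory so that state-tracking checks only run when debug logging is enabled. A small self-contained sketch of the pattern; the names are illustrative, and here the lock is taken around the wrapped call, which is what the pattern is meant to guarantee:

import logging
import threading
from functools import wraps

def synchronized(debug_only=False):
    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if (not debug_only) or self.debug:
                with self.lock:
                    return func(self, *args, **kwargs)
            return True   # debug-only bookkeeping becomes a no-op outside debug mode
        return wrapper
    return _decorator

class Pool:
    def __init__(self):
        self.lock = threading.RLock()
        self.debug = logging.getLogger(__name__).isEnabledFor(logging.DEBUG)
        self.items = []

    @synchronized()                 # always runs, under the lock
    def add(self, x):
        self.items.append(x)

    @synchronized(debug_only=True)  # skipped entirely unless DEBUG logging is on
    def check_invariants(self):
        assert len(self.items) >= 0

p = Pool()
p.add(1)
p.check_invariants()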
@@ -272,12 +272,12 @@ class PagedTokenToKVPoolAllocator:
      def free_group_end(self):
          self.is_not_in_free_group = True
          if self.free_group:
-             self.free(torch.concat(self.free_group))
+             self.free(torch.cat(self.free_group))

      def clear(self):
          # The padded slot 0 is used for writing dummy outputs from padded tokens.
          self.free_pages = torch.arange(
              1, self.num_pages + 1, dtype=torch.int64, device=self.device
          )
-         self.is_in_free_group = False
+         self.is_not_in_free_group = True
          self.free_group = []
@@ -140,7 +140,7 @@ class RadixCache(BasePrefixCache):
              return (
                  torch.empty(
                      (0,),
-                     dtype=torch.int32,
+                     dtype=torch.int64,
                      device=self.device,
                  ),
                  self.root_node,
@@ -152,9 +152,9 @@

          value, last_node = self._match_prefix_helper(self.root_node, key)
          if value:
-             value = torch.concat(value)
+             value = torch.cat(value)
          else:
-             value = torch.empty((0,), dtype=torch.int32, device=self.device)
+             value = torch.empty((0,), dtype=torch.int64, device=self.device)
          return value, last_node

      def insert(self, key: List, value=None):
@@ -317,7 +317,7 @@
              _dfs_helper(child)

          _dfs_helper(self.root_node)
-         return torch.concat(values)
+         return torch.cat(values)

      ##### Internal Helper Functions #####
