sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py
@@ -19,7 +19,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 Memory pool.
 
 SGLang has two levels of memory pool.
-ReqToTokenPool maps a a request to its token locations.
+ReqToTokenPool maps a request to its token locations.
 TokenToKVPoolAllocator manages the indices to kv cache data.
 KVCache actually holds the physical kv cache.
 """
@@ -115,6 +115,21 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_flat_data(self, indices):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        raise NotImplementedError()
+
+    def register_layer_transfer_counter(self, layer_transfer_counter):
+        self.layer_transfer_counter = layer_transfer_counter
+
 
 class TokenToKVPoolAllocator:
     """An allocator managing the indices to kv cache data."""
@@ -157,7 +172,7 @@ class TokenToKVPoolAllocator:
             return
 
         if self.is_not_in_free_group:
-            self.free_slots = torch.concat((self.free_slots, free_index))
+            self.free_slots = torch.cat((self.free_slots, free_index))
         else:
             self.free_group.append(free_index)
 
@@ -168,14 +183,14 @@ class TokenToKVPoolAllocator:
     def free_group_end(self):
         self.is_not_in_free_group = True
         if self.free_group:
-            self.free(torch.concat(self.free_group))
+            self.free(torch.cat(self.free_group))
 
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = torch.arange(
             1, self.size + 1, dtype=torch.int64, device=self.device
         )
-        self.is_in_free_group = False
+        self.is_not_in_free_group = True
        self.free_group = []
 
 
@@ -212,7 +227,8 @@ class MHATokenToKVPool(KVCache):
 
         self.layer_transfer_counter = None
         self.capture_mode = False
-        self.alt_stream = torch.cuda.Stream()
+        self.device_module = torch.get_device_module(self.device)
+        self.alt_stream = self.device_module.Stream()
 
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
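Replacing the hard-coded torch.cuda.Stream() with torch.get_device_module(self.device).Stream() keeps the side-stream logic device-agnostic. A minimal sketch of that lookup pattern, assuming a recent PyTorch that exposes torch.get_device_module:

```python
import torch


def make_alt_stream(device: str = "cuda"):
    # torch.get_device_module("cuda") returns the torch.cuda module, and the
    # same call works for other accelerator backends, so no CUDA hard-coding.
    device_module = torch.get_device_module(device)
    return device_module.Stream()


# Usage (requires a machine with the corresponding accelerator):
# alt_stream = make_alt_stream("cuda")
```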
@@ -255,6 +271,19 @@ class MHATokenToKVPool(KVCache):
             v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
         return k_size_bytes, v_size_bytes
 
+    # for disagg
+    def get_contiguous_buf_infos(self):
+        kv_data_ptrs = [
+            self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
+        ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
+        kv_data_lens = [
+            self.get_key_buffer(i).nbytes for i in range(self.layer_num)
+        ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
+        kv_item_lens = [
+            self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
+        ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
@@ -275,9 +304,6 @@ class MHATokenToKVPool(KVCache):
             self.k_buffer[i][indices] = k_data[i]
             self.v_buffer[i][indices] = v_data[i]
 
-    def register_layer_transfer_counter(self, layer_transfer_counter):
-        self.layer_transfer_counter = layer_transfer_counter
-
     def transfer_per_layer(self, indices, flat_data, layer_id):
         # transfer prepared data from host to device
         flat_data = flat_data.to(device=self.device, non_blocking=False)
@@ -326,12 +352,14 @@ class MHATokenToKVPool(KVCache):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)
 
-        if self.capture_mode:
-            self.alt_stream.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(self.alt_stream):
+        if self.capture_mode and cache_k.shape[0] < 4:
+            # Overlap the copy of K and V cache for small batch size
+            current_stream = self.device_module.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            with self.device_module.stream(self.alt_stream):
                 self.k_buffer[layer_id][loc] = cache_k
                 self.v_buffer[layer_id][loc] = cache_v
-            torch.cuda.current_stream().wait_stream(self.alt_stream)
+            current_stream.wait_stream(self.alt_stream)
         else:
             self.k_buffer[layer_id][loc] = cache_k
             self.v_buffer[layer_id][loc] = cache_v
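The rewritten branch only takes the alternate-stream path during capture mode and for small batches (cache_k.shape[0] < 4), issuing the copies on the side stream and then synchronizing back. A standalone sketch of the same stream-ordering pattern with plain CUDA streams (tensor names are invented for the example):

```python
import torch


def copy_kv_on_alt_stream(k_dst, v_dst, loc, cache_k, cache_v, alt_stream):
    # The side stream waits for work already queued on the current stream,
    # performs the K/V writes, then the current stream waits for it.
    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)
    with torch.cuda.stream(alt_stream):
        k_dst[loc] = cache_k
        v_dst[loc] = cache_v
    current_stream.wait_stream(alt_stream)
```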
@@ -388,6 +416,8 @@ class MLATokenToKVPool(KVCache):
         else:
             self.store_dtype = dtype
         self.kv_lora_rank = kv_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.layer_num = layer_num
 
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
@@ -404,12 +434,20 @@ class MLATokenToKVPool(KVCache):
                 for _ in range(layer_num)
             ]
 
+        self.layer_transfer_counter = None
+
     def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.kv_buffer[layer_id].view(self.dtype)
         return self.kv_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
         return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
@@ -432,6 +470,22 @@ class MLATokenToKVPool(KVCache):
         else:
             self.kv_buffer[layer_id][loc] = cache_k
 
+    def get_flat_data(self, indices):
+        # prepare a large chunk of contiguous data for efficient transfer
+        return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        for i in range(self.layer_num):
+            self.kv_buffer[i][indices] = flat_data[i]
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        self.kv_buffer[layer_id][indices] = flat_data
+
 
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
@@ -508,6 +562,15 @@ class DoubleSparseTokenToKVPool(KVCache):
         self.v_buffer[layer_id][loc] = cache_v
         self.label_buffer[layer_id][loc] = cache_label
 
+    def get_flat_data(self, indices):
+        pass
+
+    def transfer(self, indices, flat_data):
+        pass
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        pass
+
 
 class MemoryStateInt(IntEnum):
     IDLE = 0
@@ -517,21 +580,28 @@ class MemoryStateInt(IntEnum):
     BACKUP = 4
 
 
-def synchronized(func):
-    @wraps(func)
-    def wrapper(self, *args, **kwargs):
-        with self.lock:
-            return func(self, *args, **kwargs)
+def synchronized(debug_only=False):
+    def _decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if (not debug_only) or self.debug:
+                return func(self, *args, **kwargs)
+                with self.lock:
+                    return func(self, *args, **kwargs)
+            else:
+                return True
 
-    return wrapper
+        return wrapper
 
+    return _decorator
 
-class MHATokenToKVPoolHost:
+
+class HostKVCache(abc.ABC):
 
     def __init__(
         self,
         device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float = 3.0,
+        host_to_device_ratio: float,
         pin_memory: bool = False,  # no need to use pin memory with the double buffering
         device: str = "cpu",
     ):
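synchronized changes from a plain decorator into a decorator factory: @synchronized() always runs the wrapped method, while @synchronized(debug_only=True) skips it (returning True) unless the pool was constructed with debug logging enabled. A self-contained sketch of that decorator-with-arguments pattern, with invented class and method names (the wrapper body here simply takes the lock on the non-skipped path):

```python
import threading
from functools import wraps


def synchronized_sketch(debug_only: bool = False):
    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if (not debug_only) or self.debug:
                with self.lock:
                    return func(self, *args, **kwargs)
            return True  # debug-only method skipped outside debug mode

        return wrapper

    return _decorator


class ToyPool:
    def __init__(self, debug: bool = False):
        self.lock = threading.RLock()
        self.debug = debug

    @synchronized_sketch()
    def alloc(self):
        return "allocated"

    @synchronized_sketch(debug_only=True)
    def get_state(self):
        return "checked"


pool = ToyPool(debug=False)
assert pool.alloc() == "allocated"
assert pool.get_state() is True  # skipped: debug checks are off
```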
@@ -547,12 +617,7 @@ class MHATokenToKVPoolHost:
 
         self.size = int(device_pool.size * host_to_device_ratio)
         self.dtype = device_pool.store_dtype
-        self.head_num = device_pool.head_num
-        self.head_dim = device_pool.head_dim
-        self.layer_num = device_pool.layer_num
-        self.size_per_token = (
-            self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
-        )
+        self.size_per_token = self.get_size_per_token()
 
         # Verify there is enough available host memory.
         host_mem = psutil.virtual_memory()
@@ -571,123 +636,218 @@ class MHATokenToKVPoolHost:
                 f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
             )
 
-        self.kv_buffer = torch.zeros(
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
-            dtype=self.dtype,
-            device=self.device,
-            pin_memory=self.pin_memory,
-        )
-
-        # Initialize memory states and tracking structures.
-        self.mem_state = torch.zeros(
-            (self.size,), dtype=torch.uint8, device=self.device
-        )
-        self.free_slots = torch.arange(self.size, dtype=torch.int32)
-        self.can_use_mem_size = self.size
+        self.kv_buffer = self.init_kv_buffer()
 
         # A lock for synchronized operations on memory allocation and state transitions.
         self.lock = threading.RLock()
+        self.debug = logger.isEnabledFor(logging.DEBUG)
+        self.clear()
 
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, :, indices]
+    @abc.abstractmethod
+    def get_size_per_token(self):
+        raise NotImplementedError()
 
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, :, indices] = flat_data
+    @abc.abstractmethod
+    def init_kv_buffer(self):
+        raise NotImplementedError()
 
-    @debug_timing
+    @abc.abstractmethod
     def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, :, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_flat_data(self, indices):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_flat_data_by_layer(self, indices, layer_id):
+        raise NotImplementedError()
 
-    @synchronized
+    @abc.abstractmethod
+    def assign_flat_data(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @synchronized()
     def clear(self):
-        self.mem_state.fill_(0)
-        self.can_use_mem_size = self.size
-        self.free_slots = torch.arange(self.size, dtype=torch.int32)
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
+        self.free_slots = torch.arange(self.size, dtype=torch.int64)
 
-    @synchronized
-    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
+    def available_size(self):
+        return len(self.free_slots)
 
-    @synchronized
+    @synchronized()
     def alloc(self, need_size: int) -> torch.Tensor:
-        if need_size > self.can_use_mem_size:
+        if need_size > self.available_size():
             return None
 
-        # todo: de-fragementation
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
 
-        self.mem_state[select_index] = MemoryStateInt.RESERVED
-        self.can_use_mem_size -= need_size
+        if self.debug:
+            self.mem_state[select_index] = MemoryStateInt.RESERVED
 
         return select_index
 
-    @synchronized
+    @synchronized()
+    def free(self, indices: torch.Tensor) -> int:
+        self.free_slots = torch.cat([self.free_slots, indices])
+        if self.debug:
+            self.mem_state[indices] = MemoryStateInt.IDLE
+        return len(indices)
+
+    @synchronized(debug_only=True)
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized(debug_only=True)
     def is_reserved(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.RESERVED
 
-    @synchronized
+    @synchronized(debug_only=True)
    def is_protected(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.PROTECTED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def is_synced(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.SYNCED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def is_backup(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.BACKUP
 
-    @synchronized
+    @synchronized(debug_only=True)
     def update_backup(self, indices: torch.Tensor):
-        assert self.is_synced(indices), (
-            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-            f"Current state: {self.get_state(indices)}"
-        )
+        if not self.is_synced(indices):
+            raise ValueError(
+                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.BACKUP
 
-    @synchronized
+    @synchronized(debug_only=True)
     def update_synced(self, indices: torch.Tensor):
         self.mem_state[indices] = MemoryStateInt.SYNCED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def protect_write(self, indices: torch.Tensor):
-        assert self.is_reserved(indices), (
-            f"The host memory slots should be RESERVED before write operations. "
-            f"Current state: {self.get_state(indices)}"
-        )
+        if not self.is_reserved(indices):
+            raise ValueError(
+                f"The host memory slots should be RESERVED before write operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.PROTECTED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def protect_load(self, indices: torch.Tensor):
-        assert self.is_backup(indices), (
-            f"The host memory slots should be in BACKUP state before load operations. "
-            f"Current state: {self.get_state(indices)}"
-        )
+        if not self.is_backup(indices):
+            raise ValueError(
+                f"The host memory slots should be in BACKUP state before load operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.PROTECTED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def complete_io(self, indices: torch.Tensor):
-        assert self.is_protected(indices), (
-            f"The host memory slots should be PROTECTED during I/O operations. "
-            f"Current state: {self.get_state(indices)}"
-        )
+        if not self.is_protected(indices):
+            raise ValueError(
+                f"The host memory slots should be PROTECTED during I/O operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.SYNCED
 
-    def available_size(self):
-        return len(self.free_slots)
 
-    @synchronized
-    def free(self, indices: torch.Tensor) -> int:
-        self.mem_state[indices] = MemoryStateInt.IDLE
-        self.free_slots = torch.concat([self.free_slots, indices])
-        self.can_use_mem_size += len(indices)
-        return len(indices)
+class MHATokenToKVPoolHost(HostKVCache):
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+
+    def get_size_per_token(self):
+        self.head_num = self.device_pool.head_num
+        self.head_dim = self.device_pool.head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, :, indices] = flat_data
+
+
+class MLATokenToKVPoolHost(HostKVCache):
+    def __init__(
+        self,
+        device_pool: MLATokenToKVPool,
+        host_to_device_ratio: float,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+
+    def get_size_per_token(self):
+        self.kv_lora_rank = self.device_pool.kv_lora_rank
+        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (
+                self.layer_num,
+                self.size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[layer_id, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, indices] = flat_data
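The new HostKVCache base class drives its host-memory check from get_size_per_token(), which the two subclasses compute differently: MHATokenToKVPoolHost stores K and V for every head and layer, while MLATokenToKVPoolHost stores a single compressed vector of width kv_lora_rank + qk_rope_head_dim per entry. A worked example of the two formulas exactly as written above, with hypothetical model dimensions chosen only for illustration:

```python
itemsize = 2  # bytes, e.g. float16 / bfloat16
head_dim, head_num, layer_num = 128, 8, 32    # hypothetical MHA dimensions
kv_lora_rank, qk_rope_head_dim = 512, 64      # hypothetical MLA dimensions

# MHA: head_dim * head_num * layer_num * itemsize * 2 (K and V)
mha_bytes_per_token = head_dim * head_num * layer_num * itemsize * 2
assert mha_bytes_per_token == 131_072  # 128 KiB of host memory per cached token

# MLA, as written in the diff: (kv_lora_rank + qk_rope_head_dim) * 1 * itemsize
mla_bytes_per_token = (kv_lora_rank + qk_rope_head_dim) * 1 * itemsize
assert mla_bytes_per_token == 1_152
```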
sglang/srt/mem_cache/paged_allocator.py
@@ -272,12 +272,12 @@ class PagedTokenToKVPoolAllocator:
     def free_group_end(self):
         self.is_not_in_free_group = True
         if self.free_group:
-            self.free(torch.concat(self.free_group))
+            self.free(torch.cat(self.free_group))
 
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(
             1, self.num_pages + 1, dtype=torch.int64, device=self.device
         )
-        self.is_in_free_group = False
+        self.is_not_in_free_group = True
         self.free_group = []
sglang/srt/mem_cache/radix_cache.py
@@ -140,7 +140,7 @@ class RadixCache(BasePrefixCache):
             return (
                 torch.empty(
                     (0,),
-                    dtype=torch.int32,
+                    dtype=torch.int64,
                     device=self.device,
                 ),
                 self.root_node,
@@ -152,9 +152,9 @@ class RadixCache(BasePrefixCache):
 
         value, last_node = self._match_prefix_helper(self.root_node, key)
         if value:
-            value = torch.concat(value)
+            value = torch.cat(value)
         else:
-            value = torch.empty((0,), dtype=torch.int32, device=self.device)
+            value = torch.empty((0,), dtype=torch.int64, device=self.device)
         return value, last_node
 
     def insert(self, key: List, value=None):
@@ -317,7 +317,7 @@ class RadixCache(BasePrefixCache):
             _dfs_helper(child)
 
         _dfs_helper(self.root_node)
-        return torch.concat(values)
+        return torch.cat(values)
 
     ##### Internal Helper Functions #####
 