sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
@@ -16,7 +16,6 @@
 import asyncio
 import copy
 import dataclasses
-import json
 import logging
 import os
 import pickle
@@ -49,11 +48,9 @@ from fastapi import BackgroundTasks
 
 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.disaggregation.conn import KVBootstrapServer
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.image_processor import (
-    get_dummy_image_processor,
-    get_image_processor,
-)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -63,6 +60,8 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
+    ExpertDistributionReq,
+    ExpertDistributionReqOutput,
     FlushCacheReq,
     GenerateReqInput,
     GetInternalStateReq,
@@ -91,6 +90,11 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.multimodal_processor import (
+    get_dummy_processor,
+    get_mm_processor,
+    import_processors,
+)
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -168,27 +172,33 @@ class TokenizerManager:
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id
 
-        # Create image processor placeholder
-        self.image_processor = get_dummy_image_processor()
+        if self.model_config.is_multimodal:
+            import_processors()
+            _processor = get_processor(
+                server_args.tokenizer_path,
+                tokenizer_mode=server_args.tokenizer_mode,
+                trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
+            )
 
-        # Create tokenizer
-        if server_args.skip_tokenizer_init:
-            self.tokenizer = self.processor = None
-        else:
-            if self.model_config.is_multimodal:
-                self.processor = get_processor(
-                    server_args.tokenizer_path,
-                    tokenizer_mode=server_args.tokenizer_mode,
-                    trust_remote_code=server_args.trust_remote_code,
-                    revision=server_args.revision,
-                )
+            # We want to parallelize the image pre-processing, so we create an executor for it.
+            # We create mm_processor regardless of skip_tokenizer_init so that images
+            # are still encoded even when tokenizer initialization is skipped.
+            self.mm_processor = get_mm_processor(
+                self.model_config.hf_config, server_args, _processor
+            )
+
+            if server_args.skip_tokenizer_init:
+                self.tokenizer = self.processor = None
+            else:
+                self.processor = _processor
                 self.tokenizer = self.processor.tokenizer
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        else:
+            self.mm_processor = get_dummy_processor()
 
-                # We want to parallelize the image pre-processing so we create an executor for it
-                self.image_processor = get_image_processor(
-                    self.model_config.hf_config, server_args, self.processor
-                )
+            if server_args.skip_tokenizer_init:
+                self.tokenizer = self.processor = None
             else:
                 self.tokenizer = get_tokenizer(
                     server_args.tokenizer_path,
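The refactor above moves the multimodal check to the top level and decouples mm_processor creation from tokenizer creation. A condensed sketch of the resulting branching, with illustrative string stand-ins rather than the literal source:

# Sketch of the initialization branching above (stand-in values, not real objects).
def init_processors(model_config, server_args):
    if model_config.is_multimodal:
        # mm_processor is built unconditionally so images are still encoded
        # even when skip_tokenizer_init is set.
        mm = "mm_processor(hf_config, server_args, hf_processor)"
        tok = None if server_args.skip_tokenizer_init else "hf_processor.tokenizer"
    else:
        mm = "dummy_mm_processor"
        tok = None if server_args.skip_tokenizer_init else "hf_tokenizer"
    return mm, tok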
@@ -251,10 +261,12 @@ class TokenizerManager:
         self.start_profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.expert_distribution_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
 
         self._result_dispatcher = TypeBasedDispatcher(
             [
@@ -304,10 +316,24 @@
                     GetInternalStateReqOutput,
                     self.get_internal_state_communicator.handle_recv,
                 ),
+                (
+                    ExpertDistributionReqOutput,
+                    self.expert_distribution_communicator.handle_recv,
+                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )
 
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        # For disaggregation, start the KV bootstrap server on prefill.
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # Only start the bootstrap server on the prefill tokenizer manager.
+            self.bootstrap_server = KVBootstrapServer(
+                self.server_args.disaggregation_bootstrap_port
+            )
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
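Only the prefill-side tokenizer manager hosts the KV bootstrap server; decode instances are expected to connect to it over disaggregation_bootstrap_port. A minimal sketch of the mode check, assuming DisaggregationMode is a string-valued enum (the "prefill" value is inferred from how server_args.disaggregation_mode is passed through above):

# Assumed: DisaggregationMode wraps the string values of server_args.disaggregation_mode.
from sglang.srt.disaggregation.utils import DisaggregationMode

mode = DisaggregationMode("prefill")
if mode == DisaggregationMode.PREFILL:
    # Hypothetical: this is where KVBootstrapServer(disaggregation_bootstrap_port)
    # would be constructed, mirroring the hunk above.
    print("prefill instance hosts the KV bootstrap server")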
@@ -372,7 +398,7 @@
             )
             input_ids = self.tokenizer.encode(input_text)
 
-        image_inputs: Dict = await self.image_processor.process_images_async(
+        image_inputs: Dict = await self.mm_processor.process_mm_data_async(
             obj.image_data, input_text or input_ids, obj, self.max_req_input_len
         )
         if image_inputs and "input_ids" in image_inputs:
@@ -620,6 +646,15 @@
         req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
         self.send_to_scheduler.send_pyobj(req)
 
+    async def start_expert_distribution_record(self):
+        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
+
+    async def stop_expert_distribution_record(self):
+        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
+
+    async def dump_expert_distribution_record(self):
+        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
+
     async def update_weights_from_disk(
         self,
         obj: UpdateWeightFromDiskReqInput,
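The three new coroutines give the tokenizer manager a record lifecycle that fans out to all dp_size schedulers through the expert_distribution_communicator created earlier. A usage sketch, where tm stands for a running TokenizerManager instance (the variable name is hypothetical):

import asyncio

async def capture_expert_distribution(tm, seconds: float = 30.0):
    await tm.start_expert_distribution_record()  # broadcasts START_RECORD to schedulers
    await asyncio.sleep(seconds)                 # let traffic flow while recording
    await tm.stop_expert_distribution_record()   # broadcasts STOP_RECORD
    await tm.dump_expert_distribution_record()   # schedulers persist what they recorded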
sglang/srt/managers/tp_worker.py
@@ -132,6 +132,9 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)
 
+        # A self-reference so this class has the same member as TpModelWorkerClient.
+        self.worker = self
+
     def get_worker_info(self):
         return (
             self.max_total_num_tokens,
@@ -214,7 +217,7 @@
     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         success, message = self.model_runner.update_weights_from_tensor(
             named_tensors=MultiprocessingSerializer.deserialize(
-                recv_req.serialized_named_tensors
+                recv_req.serialized_named_tensors[self.tp_rank]
             ),
             load_format=recv_req.load_format,
         )
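Indexing serialized_named_tensors by self.tp_rank implies the request now carries one serialized payload per tensor-parallel rank rather than a single shared blob. A sketch of the shape a sender would build (tensor names and tp_size are illustrative; the serializer import path matches the deserialize call above):

import torch
from sglang.srt.utils import MultiprocessingSerializer

tp_size = 2
named_tensors = [("lm_head.weight", torch.zeros(8, 8))]
# One payload per TP rank; worker i deserializes serialized_named_tensors[i].
serialized_named_tensors = [
    MultiprocessingSerializer.serialize(named_tensors) for _ in range(tp_size)
]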
sglang/srt/managers/tp_worker_overlap_thread.py
@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -69,7 +69,7 @@ class TpModelWorkerClient:
         self.future_token_ids_ct = 0
         self.future_token_ids_limit = self.max_running_requests * 3
         self.future_token_ids_map = torch.empty(
-            (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
+            (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
         )
 
         # Launch threads
@@ -115,7 +115,7 @@
         logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
         self.parent_process.send_signal(signal.SIGQUIT)
 
-    @torch.no_grad()
+    @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
         batch_lists = [None] * 2
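DynamicGradMode replaces the hard-coded torch.no_grad(); the name suggests a decorator whose grad-disabling context can be switched at runtime, for example between no_grad and inference_mode. A sketch of that pattern, not sglang's actual implementation:

import torch

class DynamicGradModeSketch:
    """Decorator that picks the grad-disabling context at call time."""

    use_inference_mode = True  # flipped globally when inference_mode is unsafe

    def __call__(self, func):
        def wrapper(*args, **kwargs):
            ctx = torch.inference_mode() if self.use_inference_mode else torch.no_grad()
            with ctx:
                return func(*args, **kwargs)
        return wrapper

@DynamicGradModeSketch()
def forward_step():
    return torch.ones(2).sum()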
sglang/srt/managers/utils.py
@@ -1,6 +1,11 @@
+import json
 import logging
+import time
+from collections import defaultdict
 from http import HTTPStatus
-from typing import Optional
+from typing import Dict, List, Optional, Tuple
+
+import torch
 
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
 
sglang/srt/mem_cache/hiradix_cache.py
@@ -8,7 +8,10 @@ import torch
 
 from sglang.srt.managers.cache_controller import HiCacheController
 from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
     MHATokenToKVPoolHost,
+    MLATokenToKVPool,
+    MLATokenToKVPoolHost,
     ReqToTokenPool,
     TokenToKVPoolAllocator,
 )
@@ -26,14 +29,24 @@ class HiRadixCache(RadixCache):
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
+        hicache_ratio: float,
     ):
         if page_size != 1:
             raise ValueError(
                 "Page size larger than 1 is not yet supported in HiRadixCache."
             )
-        self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-            token_to_kv_pool_allocator.get_kvcache()
-        )
+        self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
+        if isinstance(self.kv_cache, MHATokenToKVPool):
+            self.token_to_kv_pool_host = MHATokenToKVPoolHost(
+                self.kv_cache, hicache_ratio
+            )
+        elif isinstance(self.kv_cache, MLATokenToKVPool):
+            self.token_to_kv_pool_host = MLATokenToKVPoolHost(
+                self.kv_cache, hicache_ratio
+            )
+        else:
+            raise ValueError("Only MHA and MLA KV caches support swapping to host.")
+
         self.tp_group = tp_cache_group
         self.page_size = page_size
 
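The constructor now provisions a host-side pool for either attention layout; hicache_ratio appears to set the host pool's capacity relative to the device KV cache. The type dispatch in isolation, as a sketch reusing the names from this hunk:

from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MHATokenToKVPoolHost,
    MLATokenToKVPool,
    MLATokenToKVPoolHost,
)

def make_host_pool(kv_cache, hicache_ratio: float):
    # hicache_ratio is interpreted here as host-pool size relative to the device pool.
    if isinstance(kv_cache, MHATokenToKVPool):
        return MHATokenToKVPoolHost(kv_cache, hicache_ratio)
    if isinstance(kv_cache, MLATokenToKVPool):
        return MLATokenToKVPoolHost(kv_cache, hicache_ratio)
    raise ValueError("Only MHA and MLA KV caches support swapping to host.")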
@@ -295,9 +308,9 @@
 
         value, last_node = self._match_prefix_helper(self.root_node, key)
         if value:
-            value = torch.concat(value)
+            value = torch.cat(value)
         else:
-            value = torch.tensor([], dtype=torch.int32)
+            value = torch.tensor([], dtype=torch.int64)
 
         last_node_global = last_node
         while last_node.evicted:
@@ -317,13 +330,11 @@
             prefix_len = _key_match(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
-                self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
                 node = new_node
                 break
             else:
-                self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
                 node = child