sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut
+from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
 from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
 
 if TYPE_CHECKING:
@@ -126,7 +126,16 @@ class SchedulerOutputProcessorMixin:
                     )
 
                 if req.grammar is not None:
-                    req.grammar.accept_token(next_token_id)
+                    # FIXME: this try-except block is for handling unexpected xgrammar issue.
+                    try:
+                        req.grammar.accept_token(next_token_id)
+                    except ValueError as e:
+                        # Grammar accept_token can raise ValueError if the token is not in the grammar.
+                        # This can happen if the grammar is not set correctly or the token is invalid.
+                        logger.error(
+                            f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
+                        )
+                        self.abort_request(AbortReq(req.rid))
                     req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished
@@ -263,7 +272,16 @@ class SchedulerOutputProcessorMixin:
                 )
 
             if req.grammar is not None and batch.spec_algorithm.is_none():
-                req.grammar.accept_token(next_token_id)
+                # FIXME: this try-except block is for handling unexpected xgrammar issue.
+                try:
+                    req.grammar.accept_token(next_token_id)
+                except ValueError as e:
+                    # Grammar accept_token can raise ValueError if the token is not in the grammar.
+                    # This can happen if the grammar is not set correctly or the token is invalid.
+                    logger.error(
+                        f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
+                    )
+                    self.abort_request(AbortReq(req.rid))
                 req.grammar.finished = req.finished()
 
         self.set_next_batch_sampling_info_done(batch)
@@ -272,7 +290,7 @@ class SchedulerOutputProcessorMixin:
 
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
         if (
-            self.attn_tp_rank == 0
+            self.current_scheduler_metrics_enabled()
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
             self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
@@ -62,6 +62,7 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
     get_tokenizer_from_processor,
 )
+from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -242,11 +243,11 @@ class TokenizerManager:
             revision=server_args.revision,
         )
 
-        # Initialize loaded loRA adapters with the initial lora paths in the server_args.
-        # This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
-        self.loaded_lora_adapters: Dict[str, str] = dict(
-            self.server_args.lora_paths or {}
-        )
+        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
+        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
+        # serves as the source of truth for available adapters and maps user-friendly LoRA names
+        # to internally used unique LoRA IDs.
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
 
         # Store states
         self.no_create_loop = False
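
The `LoRARegistry` used above is a new module in this release (`sglang/srt/lora/lora_registry.py`, +124 lines in the file list); its implementation is not shown in these hunks. As a rough sketch of the register / unregister / acquire semantics that `TokenizerManager` relies on, assuming the registry is essentially a name-to-`LoRARef` map, something like the following would behave the way the call sites expect (illustrative only, not the shipped code):

```python
import uuid
from dataclasses import dataclass, field
from typing import Dict, List, Union


@dataclass
class LoRARef:
    lora_name: str
    lora_path: str
    # A unique ID decouples the user-facing adapter name from the loaded weights.
    lora_id: str = field(default_factory=lambda: uuid.uuid4().hex)


class LoRARegistry:
    """Minimal sketch: name -> LoRARef mapping with async register/unregister/acquire."""

    def __init__(self, lora_paths: Dict[str, str]):
        self._refs: Dict[str, LoRARef] = {
            name: LoRARef(lora_name=name, lora_path=path)
            for name, path in lora_paths.items()
        }

    async def register(self, ref: LoRARef) -> None:
        self._refs[ref.lora_name] = ref

    async def unregister(self, lora_name: str) -> str:
        # Returns the internal ID so the caller can forward it to the workers.
        return self._refs.pop(lora_name).lora_id

    async def acquire(self, lora_path: Union[str, List[str]]) -> Union[str, List[str]]:
        # Translate user-friendly names into the unique IDs used internally.
        if isinstance(lora_path, list):
            return [self._refs[name].lora_id for name in lora_path]
        return self._refs[lora_path].lora_id
```
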
@@ -523,6 +524,10 @@ class TokenizerManager:
         else:
             mm_inputs = None
 
+        if self.server_args.enable_lora and obj.lora_path:
+            # Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
+            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
+
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -574,8 +579,6 @@ class TokenizerManager:
                 "The server is not configured to enable custom logit processor. "
                 "Please set `--enable-custom-logits-processor` to enable this feature."
             )
-        if self.server_args.lora_paths and obj.lora_path:
-            self._validate_lora_adapters(obj)
 
     def _validate_input_ids_in_vocab(
         self, input_ids: List[int], vocab_size: int
@@ -604,7 +607,7 @@ class TokenizerManager:
         sampling_kwargs = obj.sampling_params
         sampling_params = SamplingParams(**sampling_kwargs)
         sampling_params.normalize(self.tokenizer)
-        sampling_params.verify()
+        sampling_params.verify(self.model_config.vocab_size)
 
         # Build return object
         if isinstance(obj, GenerateReqInput):
@@ -689,21 +692,6 @@ class TokenizerManager:
                 "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
             )
 
-    def _validate_lora_adapters(self, obj: GenerateReqInput):
-        """Validate that the requested LoRA adapters are loaded."""
-        requested_adapters = (
-            set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
-        )
-        loaded_adapters = (
-            self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
-        )
-        unloaded_adapters = requested_adapters - loaded_adapters
-        if unloaded_adapters:
-            raise ValueError(
-                f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
-                f"Loaded adapters: {loaded_adapters}."
-            )
-
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -1037,6 +1025,10 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> LoadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
+        if not self.server_args.enable_lora:
+            raise ValueError(
+                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+            )
 
         # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
         # with dp_size > 1.
@@ -1050,8 +1042,18 @@ class TokenizerManager:
         )
 
         async with self.model_update_lock.writer_lock:
+            # Generate new uniquely identifiable LoRARef object.
+            new_adapter = LoRARef(
+                lora_name=obj.lora_name,
+                lora_path=obj.lora_path,
+            )
+
+            # Register the new adapter in the registry.
+            obj.lora_id = new_adapter.lora_id
             result = (await self.update_lora_adapter_communicator(obj))[0]
-            self.loaded_lora_adapters = result.loaded_adapters
+            if result.success:
+                await self.lora_registry.register(new_adapter)
+
         return result
 
     async def unload_lora_adapter(
@@ -1060,6 +1062,14 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> UnloadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
+        if not self.server_args.enable_lora:
+            raise ValueError(
+                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+            )
+
+        assert (
+            obj.lora_name is not None
+        ), "lora_name must be provided to unload LoRA adapter"
 
         # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
         # with dp_size > 1.
@@ -1072,8 +1082,9 @@ class TokenizerManager:
         )
 
         async with self.model_update_lock.writer_lock:
+            obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
             result = (await self.update_lora_adapter_communicator(obj))[0]
-            self.loaded_lora_adapters = result.loaded_adapters
+
         return result
 
     async def get_weights_by_name(
@@ -1301,7 +1312,7 @@ class TokenizerManager:
         filename = os.path.join(
             self.crash_dump_folder,
             os.getenv("HOSTNAME", None),
-            f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
+            f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
         )
 
         os.makedirs(os.path.dirname(filename), exist_ok=True)
@@ -1359,7 +1370,7 @@ class TokenizerManager:
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
-            self.last_receive_tstamp = time.time()
+            self.last_receive_tstamp = time.perf_counter()
 
     def _handle_batch_output(
         self,
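
One small change above swaps `time.time()` for `time.perf_counter()` when stamping the receive loop. `perf_counter()` is monotonic and has the highest available resolution, so interval arithmetic based on it cannot go backwards if the wall clock is adjusted (NTP steps, manual changes). A generic illustration of the pattern, not sglang code:

```python
import time

# perf_counter() is monotonic, so this difference is always >= 0 even if the
# system wall clock jumps while the work is running.
start = time.perf_counter()
time.sleep(0.05)  # stand-in for receiving and dispatching one batch
elapsed = time.perf_counter() - start
print(f"handled in {elapsed * 1000:.1f} ms")
```
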
@@ -174,6 +174,20 @@ class TpModelWorker:
             self.model_runner.token_to_kv_pool.size,
         )
 
+    @property
+    def sliding_window_size(self) -> Optional[int]:
+        return self.model_runner.sliding_window_size
+
+    @property
+    def is_hybrid(self) -> bool:
+        return self.model_runner.is_hybrid is not None
+
+    def get_tokens_per_layer_info(self):
+        return (
+            self.model_runner.full_max_total_num_tokens,
+            self.model_runner.swa_max_total_num_tokens,
+        )
+
     def get_pad_input_ids_func(self):
         return getattr(self.model_runner.model, "pad_input_ids", None)
 
@@ -279,11 +293,9 @@ class TpModelWorker:
         return parameter
 
     def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
-        result = self.model_runner.load_lora_adapter(
-            recv_req.lora_name, recv_req.lora_path
-        )
+        result = self.model_runner.load_lora_adapter(recv_req.to_ref())
         return result
 
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
-        result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
+        result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
         return result
@@ -102,6 +102,17 @@ class TpModelWorkerClient:
     def get_worker_info(self):
         return self.worker.get_worker_info()
 
+    def get_tokens_per_layer_info(self):
+        return self.worker.get_tokens_per_layer_info()
+
+    @property
+    def sliding_window_size(self) -> Optional[int]:
+        return self.worker.sliding_window_size
+
+    @property
+    def is_hybrid(self) -> bool:
+        return self.worker.is_hybrid
+
     def get_pad_input_ids_func(self):
         return self.worker.get_pad_input_ids_func()
 
@@ -51,28 +51,24 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         self._kvcache = kvcache
 
         self.free_pages = None
+        self.release_pages = None
         self.is_not_in_free_group = True
         self.free_group = []
 
     def debug_print(self) -> str:
         return ""
 
-    def log_usage(self, evictable_size: int = 0):
-        num_used = self.size - (self.available_size() + evictable_size)
-        msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, "
-        return msg, num_used
-
     def available_size(self):
-        return len(self.free_pages) * self.page_size
+        return (len(self.free_pages) + len(self.release_pages)) * self.page_size
 
     def get_kvcache(self):
         return self._kvcache
 
-    def restore_state(self, free_pages):
-        self.free_pages = free_pages
+    def restore_state(self, state):
+        self.free_pages, self.release_pages = state
 
     def backup_state(self):
-        return self.free_pages
+        return (self.free_pages, self.release_pages)
 
     def free_group_begin(self):
         self.is_not_in_free_group = False
@@ -83,6 +79,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         if self.free_group:
             self.free(torch.cat(self.free_group))
 
+    def merge_and_sort_free(self):
+        if len(self.release_pages) > 0:
+            self.free_pages = torch.cat((self.free_pages, self.release_pages))
+            self.free_pages, _ = torch.sort(self.free_pages)
+            self.release_pages = torch.empty(
+                (0,), dtype=self.release_pages.dtype, device=self.device
+            )
+
     def get_cpu_copy(self, *args, **kwargs):
         # FIXME: reuse the get_cpu_copy after paged allocator is implemented
         raise NotImplementedError()
@@ -124,12 +128,15 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
         self.is_not_in_free_group = True
         self.free_group = []
+        self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
 
     def available_size(self):
         # To avoid minor "len(free_pages) * 1" overhead
-        return len(self.free_pages)
+        return len(self.free_pages) + len(self.release_pages)
 
     def alloc(self, need_size: int):
+        if need_size > len(self.free_pages):
+            self.merge_and_sort_free()
         if need_size > len(self.free_pages):
             return None
 
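
The allocator changes above split the free list in two: `free()` now parks returned pages in `release_pages`, and they are only merged back into `free_pages` (with a sort) when an allocation cannot be satisfied from `free_pages` alone, so the sort cost is paid lazily instead of on every free. A standalone toy version of the same pattern (hypothetical class, not the sglang allocator):

```python
import torch


class LazyFreeList:
    """Toy illustration of the free_pages / release_pages split."""

    def __init__(self, num_pages: int):
        self.free_pages = torch.arange(num_pages, dtype=torch.int64)
        self.release_pages = torch.empty((0,), dtype=torch.int64)

    def free(self, pages: torch.Tensor) -> None:
        # Hot path: just park the pages, no sorting.
        self.release_pages = torch.cat((self.release_pages, pages))

    def merge_and_sort_free(self) -> None:
        if len(self.release_pages) > 0:
            self.free_pages, _ = torch.sort(
                torch.cat((self.free_pages, self.release_pages))
            )
            self.release_pages = torch.empty((0,), dtype=torch.int64)

    def alloc(self, need: int):
        # Only pay for the merge + sort when the fast path cannot serve the request.
        if need > len(self.free_pages):
            self.merge_and_sort_free()
        if need > len(self.free_pages):
            return None
        out, self.free_pages = self.free_pages[:need], self.free_pages[need:]
        return out
```
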
@@ -142,7 +149,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             return
 
         if self.is_not_in_free_group:
-            self.free_pages = torch.cat((self.free_pages, free_index))
+            self.release_pages = torch.cat((self.release_pages, free_index))
         else:
             self.free_group.append(free_index)
 
@@ -190,7 +197,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping
 
     def available_size(self):
-        return min(self.full_available_size(), self.swa_available_size())
+        raise NotImplementedError()
 
     def full_available_size(self):
         return self.full_attn_allocator.available_size()
@@ -214,16 +221,6 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
         return msg
 
-    def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0):
-        used_full = self.size_full - (self.full_available_size() + full_evictable_size)
-        used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size)
-        msg = (
-            f"#token: full={used_full}, swa={used_swa}, "
-            f"token usage: full={used_full / self.size_full:.2f}, "
-            f"swa={used_swa / self.size_swa:.2f}, "
-        )
-        return msg, used_full
-
     def get_kvcache(self):
         return self._kvcache
 
@@ -436,6 +433,8 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         ), "The allocation size should be page-aligned"
 
         num_pages = need_size // self.page_size
+        if num_pages > len(self.free_pages):
+            self.merge_and_sort_free()
         if num_pages > len(self.free_pages):
             return None
 
@@ -461,6 +460,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )
 
+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int64, device=self.device
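
The `estimated_num_new_pages` expression added above counts, per request, how many fresh pages the extend needs, using the ceiling-division identity `ceil(x / p) == (x + p - 1) // p`, and triggers `merge_and_sort_free()` only when that estimate exceeds what `free_pages` currently holds. A small worked example with made-up lengths:

```python
import torch

page_size = 16
prefix_lens = torch.tensor([10, 30, 64])  # tokens already cached per request
seq_lens = torch.tensor([20, 33, 80])     # target lengths after the extend

# ceil(seq / page_size) - ceil(prefix / page_size), elementwise
new_pages = (
    (seq_lens + page_size - 1) // page_size
    - (prefix_lens + page_size - 1) // page_size
)
print(new_pages)               # tensor([1, 1, 1])
print(new_pages.sum().item())  # 3 pages needed before touching release_pages
```
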
@@ -498,6 +508,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )
 
+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (seq_lens - 1 + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
         alloc_decode_kernel[(bs,)](
@@ -526,7 +547,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
 
         if self.is_not_in_free_group:
             free_page_indices = torch.unique(free_index // self.page_size)
-            self.free_pages = torch.cat((free_page_indices, self.free_pages))
+            self.release_pages = torch.cat((free_page_indices, self.release_pages))
         else:
             self.free_group.append(free_index)
 
@@ -540,6 +561,13 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
         self.is_not_in_free_group = True
         self.free_group = []
+        self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)
+
+    def get_cpu_copy(self, indices):
+        return self._kvcache.get_cpu_copy(indices)
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
 
 
 def alloc_extend_kernel_ascend(
@@ -642,6 +670,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )
 
+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int32, device=self.device
@@ -677,6 +716,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )
 
+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (seq_lens - 1 + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
 
@@ -701,3 +751,4 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
     def clear(self):
         super().clear()
         self.free_pages = self.free_pages.to(torch.int32)
+        self.release_pages = self.release_pages.to(torch.int32)
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
+from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Tuple
 
 import torch
 
@@ -56,15 +56,27 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def dec_lock_ref(self, node: Any):
+    def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
         pass
 
     def evictable_size(self):
         return 0
 
+    def full_evictable_size(self):
+        return 0
+
+    def swa_evictable_size(self):
+        return 0
+
     def protected_size(self):
         return 0
 
+    def full_protected_size(self):
+        return 0
+
+    def swa_protected_size(self):
+        return 0
+
     def total_size(self):
         raise NotImplementedError()
 
@@ -61,7 +61,7 @@ class ChunkCache(BasePrefixCache):
     def inc_lock_ref(self, node: Any):
         return 0
 
-    def dec_lock_ref(self, node: Any):
+    def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
         return 0
 
     def pretty_print(self):
@@ -80,7 +80,7 @@ class SWAChunkCache(ChunkCache):
         super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
         assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
 
-    def evict(
+    def evict_swa(
         self,
         req: Req,
         prelen: int,
@@ -95,3 +95,6 @@ class SWAChunkCache(ChunkCache):
             ]
             self.token_to_kv_pool_allocator.free_swa(free_slots)
             req.evicted_seqlen_local = new_evicted_seqlen_local
+
+    def evict(self, num_tokens: int):
+        pass
@@ -0,0 +1,168 @@
+import hashlib
+import logging
+import os
+from abc import ABC, abstractmethod
+from typing import List, Optional
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+
+
+def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
+    hasher = hashlib.sha256()
+
+    if prior_hash:
+        hasher.update(bytes.fromhex(prior_hash))
+
+    for t in token_ids:
+        hasher.update(t.to_bytes(4, byteorder="little", signed=False))
+
+    return hasher.hexdigest()
+
+
+class HiCacheStorage(ABC):
+    """
+    HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
+    It abstracts the underlying storage mechanism, allowing different implementations to be used.
+    """
+
+    # todo, translate tensor object access for different TP ranks
+    # potentially pass model and TP configs into storage backend
+    # todo, the page size of storage backend does not have to be the same as the same as host memory pool
+
+    @abstractmethod
+    def get(
+        self, key: str, target_location: Optional[torch.Tensor] = None
+    ) -> torch.Tensor | None:
+        """
+        Retrieve the value associated with the given key.
+        Returns None if the key does not exist.
+        """
+        pass
+
+    @abstractmethod
+    def batch_get(
+        self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None
+    ) -> List[torch.Tensor | None]:
+        """
+        Retrieve values for multiple keys.
+        Returns a list of tensors or None for each key.
+        """
+        pass
+
+    @abstractmethod
+    def set(self, key, value) -> bool:
+        """
+        Store the value associated with the given key.
+        Returns True if the operation was successful, False otherwise.
+        """
+        pass
+
+    @abstractmethod
+    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+        """
+        Store multiple key-value pairs.
+        Returns True if all operations were successful, False otherwise.
+        """
+        pass
+
+    @abstractmethod
+    def exists(self, key: str) -> bool:
+        """
+        Check if the key exists in the storage.
+        Returns True if the key exists, False otherwise.
+        """
+        pass
+
+
+class HiCacheFile(HiCacheStorage):
+
+    def __init__(self, file_path: str = "/tmp/hicache"):
+        self.file_path = file_path
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
+        if not os.path.exists(self.file_path) and tp_rank == 0:
+            os.makedirs(self.file_path)
+            logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
+
+    def _get_suffixed_key(self, key: str) -> str:
+        return key + self.tp_suffix
+
+    def get(
+        self, key: str, target_location: Optional[torch.Tensor] = None
+    ) -> torch.Tensor | None:
+        key = self._get_suffixed_key(key)
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        try:
+            # todo: fixing the target_location logic to enable in-place loading
+            loaded_tensor = torch.load(tensor_path)
+            if isinstance(loaded_tensor, torch.Tensor):
+                return loaded_tensor
+            else:
+                logger.error(f"Loaded data for key {key} is not a tensor.")
+                return None
+        except FileNotFoundError:
+            return None
+
+    def batch_get(
+        self,
+        keys: List[str],
+        target_locations: Optional[List[torch.Tensor]] = None,
+    ) -> List[torch.Tensor | None]:
+        return [
+            self.get(key, target_location)
+            for key, target_location in zip(
+                keys, target_locations or [None] * len(keys)
+            )
+        ]
+
+    def set(self, key: str, value: torch.Tensor) -> bool:
+        key = self._get_suffixed_key(key)
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        if self.exists(key):
+            logger.debug(f"Key {key} already exists. Skipped.")
+            return True
+        try:
+            torch.save(value, tensor_path)
+            return True
+        except Exception as e:
+            logger.error(f"Failed to save tensor {key}: {e}")
+            return False
+
+    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+        for key, value in zip(keys, values):
+            if not self.set(key, value):
+                return False
+        return True
+
+    def exists(self, key: str) -> bool:
+        key = self._get_suffixed_key(key)
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        return os.path.exists(tensor_path)
+
+    def delete(self, key: str) -> None:
+        key = self._get_suffixed_key(key)
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
+        try:
+            os.remove(tensor_path)
+        except FileNotFoundError:
+            logger.warning(f"Key {key} does not exist. Cannot delete.")
+            return
+
+    def clear(self) -> None:
+        try:
+            for filename in os.listdir(self.file_path):
+                file_path = os.path.join(self.file_path, filename)
+                if os.path.isfile(file_path):
+                    os.remove(file_path)
+            logger.info("Cleared all entries in HiCacheFile storage.")
+        except Exception as e:
+            logger.error(f"Failed to clear HiCacheFile storage: {e}")