sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -78,6 +78,7 @@ from sglang.srt.managers.io_struct import (
     ExpertDistributionReqOutput,
     FlushCacheReqInput,
     FlushCacheReqOutput,
+    FreezeGCReq,
     GenerateReqInput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
@@ -122,7 +123,9 @@ from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    configure_gc_warning,
     dataclass_to_string_truncated,
+    freeze_gc,
     get_bool_env_var,
     get_zmq_socket,
     kill_process_tree,
@@ -298,7 +301,7 @@ class TokenizerManager:
         # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
         # serves as the source of truth for available adapters and maps user-friendly LoRA names
         # to internally used unique LoRA IDs.
-        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths)
         # Lock to serialize LoRA update operations.
         # Please note that, unlike `model_update_lock`, this does not block inference, allowing
         # LoRA updates and inference to overlap.
@@ -352,6 +355,10 @@ class TokenizerManager:
                 collect_tokens_histogram=self.server_args.collect_tokens_histogram,
             )
 
+        # Configure GC warning
+        if self.server_args.gc_warning_threshold_secs > 0.0:
+            configure_gc_warning(self.server_args.gc_warning_threshold_secs)
+
         # Communicators
         self.init_weights_update_group_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
@@ -446,6 +453,10 @@ class TokenizerManager:
                     ProfileReqOutput,
                     self.profile_communicator.handle_recv,
                 ),
+                (
+                    FreezeGCReq,
+                    lambda x: None,
+                ),  # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
                 (
                     GetInternalStateReqOutput,
                     self.get_internal_state_communicator.handle_recv,
@@ -565,14 +576,24 @@ class TokenizerManager:
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
     ) -> None:
         """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
+        # FIXME: unify the length validation logic with the one in the scheduler.
+        _max_req_len = self.context_len
 
         input_token_num = len(input_ids) if input_ids is not None else 0
-        # Check if input alone exceeds context length
         if input_token_num >= self.context_len:
-            raise ValueError(
-                f"The input ({input_token_num} tokens) is longer than the "
-                f"model's context length ({self.context_len} tokens)."
-            )
+            if self.server_args.allow_auto_truncate:
+                logger.warning(
+                    f"The input ({input_token_num} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens). "
+                    "Truncating the input."
+                )
+                del input_ids[_max_req_len:]
+                input_token_num = len(input_ids)
+            else:
+                raise ValueError(
+                    f"The input ({input_token_num} tokens) is longer than the "
+                    f"model's context length ({self.context_len} tokens)."
+                )
 
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
@@ -584,17 +605,27 @@ class TokenizerManager:
         max_new_tokens = obj.sampling_params.get("max_new_tokens")
         if (
             max_new_tokens is not None
-            and (max_new_tokens + input_token_num) >= self.context_len
+            and (max_new_tokens + input_token_num) >= _max_req_len
         ):
-            total_tokens = max_new_tokens + input_token_num
-            error_msg = (
-                f"Requested token count exceeds the model's maximum context length "
-                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
-                f"tokens: {input_token_num} tokens from the input messages and "
-                f"{max_new_tokens} tokens for the completion. Please reduce the number "
-                f"of tokens in the input messages or the completion to fit within the limit."
-            )
-            raise ValueError(error_msg)
+            if self.server_args.allow_auto_truncate:
+                logger.warning(
+                    f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
+                    f"exceeds the model's context length ({self.context_len} tokens). "
+                    "Truncating max_new_tokens."
+                )
+                obj.sampling_params["max_new_tokens"] = max(
+                    0, _max_req_len - input_token_num
+                )
+            else:
+                total_tokens = max_new_tokens + input_token_num
+                error_msg = (
+                    f"Requested token count exceeds the model's maximum context length "
+                    f"of {self.context_len} tokens. You requested a total of {total_tokens} "
+                    f"tokens: {input_token_num} tokens from the input messages and "
+                    f"{max_new_tokens} tokens for the completion. Please reduce the number "
+                    f"of tokens in the input messages or the completion to fit within the limit."
+                )
+                raise ValueError(error_msg)
 
         if isinstance(obj, GenerateReqInput):
             if (
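With --allow-auto-truncate set, the tokenizer manager now clamps an oversized request instead of rejecting it: the prompt is cut to the context window and max_new_tokens is reduced so prompt plus completion still fits. A minimal standalone sketch of that arithmetic (the helper name clamp_request is hypothetical, not part of sglang):

def clamp_request(input_ids, max_new_tokens, context_len):
    # Trim the prompt if it alone exceeds the context window.
    if len(input_ids) >= context_len:
        input_ids = input_ids[:context_len]
    # Cap the completion budget so prompt + completion stays within the window.
    if max_new_tokens is not None and len(input_ids) + max_new_tokens >= context_len:
        max_new_tokens = max(0, context_len - len(input_ids))
    return input_ids, max_new_tokens


# Example: an 8192-token window, an 8000-token prompt, and 500 requested new
# tokens leave room for at most 192 new tokens.
ids, new_tokens = clamp_request(list(range(8000)), 500, 8192)
assert len(ids) == 8000 and new_tokens == 192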
@@ -782,15 +813,17 @@ class TokenizerManager:
         ):
             raise ValueError(finish_reason["message"])
 
-        if (
-            finish_reason.get("type") == "abort"
-            and finish_reason.get("status_code")
-            == HTTPStatus.SERVICE_UNAVAILABLE
+        if finish_reason.get("type") == "abort" and finish_reason.get(
+            "status_code"
+        ) in (
+            HTTPStatus.SERVICE_UNAVAILABLE,
+            HTTPStatus.INTERNAL_SERVER_ERROR,
         ):
             # This is an abort request initiated by scheduler.
             # Delete the key to prevent resending abort request to the scheduler and
             # to ensure aborted request state is cleaned up.
-            del self.rid_to_state[state.obj.rid]
+            if state.obj.rid in self.rid_to_state:
+                del self.rid_to_state[state.obj.rid]
 
         # Mark ongoing LoRA request as finished.
         if self.server_args.enable_lora and state.obj.lora_path:
@@ -1337,6 +1370,12 @@ class TokenizerManager:
         logging.info(f"Config logging: {obj=}")
         self.log_request_metadata = self.get_log_request_metadata()
 
+    async def freeze_gc(self):
+        """Send a freeze_gc message to the scheduler first, then freeze locally."""
+        self.send_to_scheduler.send_pyobj(FreezeGCReq())
+        freeze_gc("Tokenizer Manager")
+        return None
+
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
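The new TokenizerManager.freeze_gc fans a FreezeGCReq out to the scheduler and then freezes its own process. A rough sketch of what a freeze_gc(name) helper like the one imported from sglang.srt.utils typically does, built on Python's standard gc module (the real helper's logging and details may differ):

import gc
import logging

logger = logging.getLogger(__name__)


def freeze_gc_sketch(context: str) -> None:
    # Collect once, then move every surviving object into the permanent
    # generation so future GC passes skip long-lived server state.
    gc.collect()
    gc.freeze()
    logger.info("%s froze %d objects", context, gc.get_freeze_count())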
@@ -92,6 +92,7 @@ class TpModelWorker:
             pp_rank=pp_rank,
             pp_size=server_args.pp_size,
             nccl_port=nccl_port,
+            dp_rank=dp_rank,
             server_args=server_args,
             is_draft_worker=is_draft_worker,
             req_to_token_pool=req_to_token_pool,
@@ -1,9 +1,16 @@
+from __future__ import annotations
+
 import logging
 import multiprocessing as mp
 from http import HTTPStatus
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
 
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import GenerationBatchResult
 
 logger = logging.getLogger(__name__)
 
@@ -41,6 +48,57 @@ def validate_input_length(
     return None
 
 
+def get_logprob_dict_from_result(result: GenerationBatchResult) -> dict:
+
+    logits_output = result.logits_output
+    assert logits_output is not None
+
+    return {
+        "extend_input_len_per_req": result.extend_input_len_per_req,
+        "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
+        "next_token_logprobs": result.logits_output.next_token_logprobs,
+        "next_token_top_logprobs_val": result.logits_output.next_token_top_logprobs_val,
+        "next_token_top_logprobs_idx": result.logits_output.next_token_top_logprobs_idx,
+        "next_token_token_ids_logprobs_val": result.logits_output.next_token_token_ids_logprobs_val,
+        "next_token_token_ids_logprobs_idx": result.logits_output.next_token_token_ids_logprobs_idx,
+        "input_token_logprobs": result.logits_output.input_token_logprobs,
+        "input_top_logprobs_val": result.logits_output.input_top_logprobs_val,
+        "input_top_logprobs_idx": result.logits_output.input_top_logprobs_idx,
+        "input_token_ids_logprobs_val": result.logits_output.input_token_ids_logprobs_val,
+        "input_token_ids_logprobs_idx": result.logits_output.input_token_ids_logprobs_idx,
+    }
+
+
+def get_logprob_from_pp_outputs(
+    next_pp_outputs: PPProxyTensors,
+) -> tuple[LogitsProcessorOutput, list[int], list[int]]:
+    logits_output = LogitsProcessorOutput(
+        # Do not send logits and hidden states because they are large
+        next_token_logits=None,
+        hidden_states=None,
+        next_token_logprobs=next_pp_outputs["next_token_logprobs"],
+        next_token_top_logprobs_val=next_pp_outputs["next_token_top_logprobs_val"],
+        next_token_top_logprobs_idx=next_pp_outputs["next_token_top_logprobs_idx"],
+        next_token_token_ids_logprobs_val=next_pp_outputs[
+            "next_token_token_ids_logprobs_val"
+        ],
+        next_token_token_ids_logprobs_idx=next_pp_outputs[
+            "next_token_token_ids_logprobs_idx"
+        ],
+        input_token_logprobs=next_pp_outputs["input_token_logprobs"],
+        input_top_logprobs_val=next_pp_outputs["input_top_logprobs_val"],
+        input_top_logprobs_idx=next_pp_outputs["input_top_logprobs_idx"],
+        input_token_ids_logprobs_val=next_pp_outputs["input_token_ids_logprobs_val"],
+        input_token_ids_logprobs_idx=next_pp_outputs["input_token_ids_logprobs_idx"],
+    )
+    extend_input_len_per_req = next_pp_outputs["extend_input_len_per_req"]
+    extend_logprob_start_len_per_req = next_pp_outputs[
+        "extend_logprob_start_len_per_req"
+    ]
+
+    return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
+
+
 class DPBalanceMeta:
     """
     This class will be use in scheduler and dp controller
@@ -434,15 +434,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         device: str,
         kvcache: KVCache,
         need_sort: bool,
-        max_num_extend_tokens: int,
     ):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
-        self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
-            max_num_extend_tokens
-        )
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
+        self.seen_max_num_extend_tokens_next_power_of_2 = 1
         self.clear()
 
     def alloc(self, need_size: int):
@@ -480,6 +477,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 (last_loc + 1) % self.page_size == prefix_lens % self.page_size
             )
 
+        self.seen_max_num_extend_tokens_next_power_of_2 = max(
+            self.seen_max_num_extend_tokens_next_power_of_2,
+            next_power_of_2(extend_num_tokens),
+        )
+
         bs = len(prefix_lens)
         if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
             self.free_pages
@@ -498,7 +500,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             self.ret_values,
             next_power_of_2(bs),
             self.page_size,
-            self.max_num_extend_tokens_next_power_of_2,
+            self.seen_max_num_extend_tokens_next_power_of_2,
         )
 
         if self.debug_mode:
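Instead of fixing the kernel bound from a max_num_extend_tokens constructor argument, the allocator now tracks the largest extend size it has actually seen, rounded up to a power of two (the static size the Triton kernel is specialized on). A small standalone illustration of that running-max pattern; the next_power_of_2 below is a local stand-in for sglang's helper, not the library function itself:

def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (stand-in for sglang's helper).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


seen = 1  # mirrors seen_max_num_extend_tokens_next_power_of_2
for extend_num_tokens in (700, 300, 5000, 1200):
    # The current power-of-two bound is what gets passed to the kernel.
    seen = max(seen, next_power_of_2(extend_num_tokens))
print(seen)  # 8192, large enough for every batch observed so far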
@@ -66,17 +66,6 @@ def alloc_extend_kernel_ascend(
 
 
 class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
-    def __init__(
-        self,
-        size: int,
-        page_size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-        need_sort: bool,
-    ):
-        super().__init__(size, page_size, dtype, device, kvcache, need_sort, 1)
-
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
@@ -13,6 +13,11 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 
 
 def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
@@ -101,11 +106,16 @@ class HiCacheStorage(ABC):
 
 class HiCacheFile(HiCacheStorage):
 
-    def __init__(self, file_path: str = "/tmp/hicache"):
+    def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
-        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
+        if is_dp_attention_enabled():
+            tp_rank = get_attention_tp_rank()
+            tp_size = get_attention_tp_size()
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            tp_size = get_tensor_model_parallel_world_size()
+
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
         if not os.path.exists(self.file_path) and tp_rank == 0:
            os.makedirs(self.file_path)
            logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
@@ -849,7 +849,7 @@ class MLATokenToKVPool(KVCache):
             cache_k_rope = cache_k_rope.view(self.store_dtype)
 
         set_mla_kv_buffer_triton(
-            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
+            self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
         )
 
     def get_cpu_copy(self, indices):
@@ -951,7 +951,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         cache_k = cache_k.to(self.dtype)
 
         if self.store_dtype != self.dtype:
-            cache_k = cache_k.view(store_dtype)
+            cache_k = cache_k.view(self.store_dtype)
 
         import torch_npu
 
@@ -1070,7 +1070,7 @@ def copy_all_layer_kv_cache(
     num_loop = tl.cdiv(stride, BLOCK_SIZE)
     for i in range(num_loop):
         copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :]
+        mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :]
         value = tl.load(
             data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
         )
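The mask fix swaps Python's short-circuit and for the elementwise & operator: and asks for a single truth value, which is not what you want when combining two boolean blocks into a 2-D load mask (inside a @triton.jit kernel the failure mode differs because the code is compiled, but the intent is the same elementwise combination). The distinction is easy to demonstrate with plain PyTorch tensors, outside Triton:

import torch

rows = torch.tensor([True, True, False]).unsqueeze(1)  # shape (3, 1)
cols = torch.tensor([True, False, True]).unsqueeze(0)  # shape (1, 3)

mask = rows & cols  # elementwise AND, broadcasts to shape (3, 3)
print(mask.shape)   # torch.Size([3, 3])

try:
    mask = rows and cols  # `and` needs one bool for the whole tensor
except RuntimeError as err:
    print(err)  # "Boolean value of Tensor with more than one element is ambiguous"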
@@ -7,6 +7,7 @@ from functools import wraps
 import psutil
 import torch
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
 from sglang.srt.utils import is_npu
 
@@ -307,6 +308,9 @@ class MHATokenToKVPoolHost(HostKVCache):
 
         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
 
+    def get_ksize_per_token(self):
+        return self.get_size_per_token() // 2
+
     def init_kv_buffer(self):
         if self.layout == "layer_first":
             dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
@@ -484,8 +488,8 @@ class MHATokenToKVPoolHost(HostKVCache):
             ptr_list.append(k_ptr)
             ptr_list.append(v_ptr)
             key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_k")
-            key_list.append(f"{key_}_v")
+            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_k")
+            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_v")
         element_size = (
             self.layer_num
             * self.dtype.itemsize
@@ -496,6 +500,21 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
 
+    def get_buffer_with_hash(self, keys, indices):
+        assert self.layout == "page_first"
+        assert len(keys) == (len(indices) // self.page_size)
+
+        key_list = []
+        buf_list = []
+
+        for key, i in zip(keys, range(0, len(indices), self.page_size)):
+            key_list.append(f"{key}-k")
+            buf_list.append(self.k_buffer[i : i + self.page_size])
+            key_list.append(f"{key}-v")
+            buf_list.append(self.v_buffer[i : i + self.page_size])
+
+        return key_list, buf_list
+
 
 class MLATokenToKVPoolHost(HostKVCache):
     device_pool: MLATokenToKVPool
@@ -538,6 +557,9 @@ class MLATokenToKVPoolHost(HostKVCache):
             * self.layer_num
         )
 
+    def get_ksize_per_token(self):
+        return self.get_size_per_token()
+
     def init_kv_buffer(self):
         if self.layout == "layer_first":
             dims = (
@@ -704,3 +726,14 @@ class MLATokenToKVPoolHost(HostKVCache):
         )
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
+
+    def get_buffer_with_hash(self, keys, indices):
+        assert self.layout == "page_first"
+        assert len(keys) == (len(indices) // self.page_size)
+
+        buf_list = []
+
+        for i in range(0, len(indices), self.page_size):
+            buf_list.append(self.kv_buffer[i : i + self.page_size])
+
+        return keys, buf_list
@@ -7,10 +7,15 @@ import signal
 import threading
 from abc import ABC, abstractmethod
 from functools import wraps
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import torch
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    is_dp_attention_enabled,
+)
 from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
 from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
 
@@ -167,13 +172,20 @@ class HiCacheHF3FS(HiCacheStorage):
 
     @staticmethod
     def from_env_config(
-        rank: int, bytes_per_page: int, dtype: torch.dtype
+        bytes_per_page: int, dtype: torch.dtype, rank: int = None
     ) -> "HiCacheHF3FS":
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )
 
+        if rank is None:
+            rank = (
+                get_attention_tp_rank()
+                if is_dp_attention_enabled()
+                else get_tensor_model_parallel_rank()
+            )
+
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
             return HiCacheHF3FS(
@@ -228,15 +240,23 @@ class HiCacheHF3FS(HiCacheStorage):
         )
 
     def get(
-        self, key: str, target_location: Optional[torch.Tensor] = None
+        self,
+        key: str,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
-        return self.batch_get([key], [target_location] if target_location else None)[0]
+        return self.batch_get(
+            [key],
+            [target_location] if target_location is not None else None,
+            [target_sizes] if target_sizes is not None else None,
+        )[0]
 
     @synchronized()
     def batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[List[torch.Tensor]] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         page_indices = self.metadata_client.get_page_indices(self.rank, keys)
 
@@ -246,9 +266,15 @@ class HiCacheHF3FS(HiCacheStorage):
             batch_indices.append(i)
             file_offsets.append(page_index * self.bytes_per_page)
 
-        file_results = [
-            torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
-        ]
+        if target_locations is not None:
+            for target_location in target_locations:
+                assert target_location.is_contiguous()
+            file_results = target_locations
+        else:
+            file_results = [
+                torch.empty(self.numel, dtype=self.dtype)
+                for _ in range(len(batch_indices))
+            ]
 
         futures = [
             self.executor.submit(
@@ -273,10 +299,27 @@ class HiCacheHF3FS(HiCacheStorage):
 
         return results
 
-    def set(self, key: str, value: torch.Tensor) -> bool:
-        return self.batch_set([key], [value])
+    def set(
+        self,
+        key: str,
+        value: Optional[Any] = None,
+        target_location: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
+        return self.batch_set(
+            [key],
+            [value] if value is not None else None,
+            [target_location] if target_location is not None else None,
+            [target_sizes] if target_sizes is not None else None,
+        )
 
-    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[Any] = None,
+        target_locations: Optional[Any] = None,
+        target_sizes: Optional[Any] = None,
+    ) -> bool:
         # Todo: Add prefix block's hash key
         key_with_prefix = [(key, "") for key in keys]
         indices = self.metadata_client.reserve_and_allocate_page_indices(
@@ -292,7 +335,8 @@ class HiCacheHF3FS(HiCacheStorage):
 
             batch_indices.append(i)
             file_offsets.append(page_index * self.bytes_per_page)
-            file_values.append(value.contiguous())
+            assert value.is_contiguous()
+            file_values.append(value)
 
         futures = [
             self.executor.submit(
@@ -19,14 +19,13 @@ logger = logging.getLogger(__name__)
 
 
 def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
-    local_rank = get_tensor_model_parallel_rank()
     prefix_str = ""
     if prior_hash:
         prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
     current_token_ids_bytes = np.array(token_ids).tobytes()
     current_hash_object = hashlib.sha256(current_token_ids_bytes)
     current_hash_hex = current_hash_object.hexdigest()
-    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
+    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}"
 
 
 @dataclass
@@ -97,7 +96,7 @@ class MooncakeStoreConfig:
 
 
 class MooncakeStore(HiCacheStorage):
-    def __init__(self):
+    def __init__(self, is_mla: bool = False):
         try:
             from mooncake.store import MooncakeDistributedStore
         except ImportError as e:
@@ -127,6 +126,7 @@ class MooncakeStore(HiCacheStorage):
             logger.info("Connect to Mooncake store successfully.")
             self.warmup()
             logger.info("Mooncake store warmup successfully.")
+            self.is_mla = is_mla
 
         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
@@ -223,11 +223,15 @@ class MooncakeStore(HiCacheStorage):
 
     def exists(self, keys) -> bool | dict:
         _keys = []
+        local_rank = get_tensor_model_parallel_rank()
         for key in keys:
             if key is None:
                 return None
 
-            _keys.append(f"{key}_k")
+            if self.is_mla:
+                _keys.append(f"{key}_k")
+            else:
+                _keys.append(f"{key}_{local_rank}_k")
         result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
         return result