sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/nixl/hicache_nixl.py

@@ -3,7 +3,7 @@ import logging
 import os
 import time
 import uuid
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -28,6 +28,8 @@ class HiCacheNixl(HiCacheStorage):
 
     def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
         """Initialize NIXL storage connector."""
+        # Might be better to be unified across HiCache backends and moved to HiCacheController
+        file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
         self.file_manager = (
             NixlFileManager(file_path)
             if plugin not in NixlBackendSelection.OBJ_PLUGINS
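The new environment override above lets a deployment relocate the NIXL file backend without touching code. A minimal sketch of how it could be used (the target path is illustrative):

import os

# Hypothetical setup: redirect HiCacheNixl's file backend before construction;
# without the variable, the constructor default ("/tmp/hicache_storage") is used.
os.environ["SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR"] = "/mnt/nvme/hicache"

from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl

cache = HiCacheNixl()  # file_path now resolves to /mnt/nvme/hicache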
@@ -44,59 +46,109 @@ class HiCacheNixl(HiCacheStorage):
 
         self.registration = NixlRegistration(self.agent)
 
+    def register_buffers(
+        self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
+    ) -> Optional[Any]:
+        """Register tensor(s) or target locations in host memory (list of addr,len tuples) with NIXL."""
+        if isinstance(buffers[0], tuple):
+            tuples = [(x[0], x[1], 0, "") for x in buffers]
+            return self.registration._register_memory(tuples, "DRAM")
+        else:
+            return self.registration._register_memory(buffers)
+
+    def register_files(
+        self, file_paths: List[str], open_file: Optional[bool] = True
+    ) -> Optional[Any]:
+        """Register files with NIXL."""
+        tuples = self.file_manager.files_to_nixl_tuples(file_paths)
+        return self.registration._register_memory(tuples, "FILE")
+
+    def register_objects(
+        self, keys: List[str], sizes: Optional[List[int]] = None
+    ) -> Optional[Any]:
+        """Register objects with NIXL."""
+        if not keys:
+            return None
+        tuples = [(0, 0, key, "") for key in keys]
+        return self.registration._register_memory(tuples, "OBJ")
+
     def _execute_transfer(
-        self, tensors: List[torch.Tensor], keys: List[str], direction: str
+        self,
+        buffers: Optional[List[torch.Tensor | tuple]],
+        keys: List[str],
+        direction: str,
     ) -> bool:
-        if len(tensors) != len(keys):
-            logger.error("Mismatch between number of tensors and files/objects")
+        if len(buffers) != len(keys):
+            logger.error("Mismatch between number of tensors/buffers and files/objects")
             return False
 
-        if not self.registration.register_buffers(tensors):
-            logger.error("Failed to register tensors")
-            return False
-
-        # Get transfer tuples based on backend type
-        tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
+        # Registering file and object keys per transfer, to be updated when
+        # pre-registration for file and object is added to HiCache.
         if self.backend_selector.mem_type == "FILE":
-            file_tuples = self.file_manager.files_to_nixl_tuples(keys)
-            if not file_tuples or not self.registration.register_files(file_tuples):
+            tuples = self.file_manager.files_to_nixl_tuples(keys)
+            if not tuples or not self.registration._register_memory(tuples, "FILE"):
                 logger.error("Failed to prepare files for transfer")
                 return False
-            transfer_tuples = [
-                (x[0], s, x[2]) for x, s in zip(file_tuples, tensor_sizes)
-            ]
-        else:
-            if not self.registration.register_objects(keys, tensors):
+        else:  # mem_type == "OBJ"
+            tuples = [(0, 0, key, "") for key in keys]
+            if not tuples or not self.registration._register_memory(tuples, "OBJ"):
                 logger.error("Failed to register objects")
                 return False
-            transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]
 
+        # Prepare transfer descriptors
+        if isinstance(buffers[0], torch.Tensor):
+            tensor_sizes = [
+                tensor.element_size() * tensor.numel() for tensor in buffers
+            ]
+            storage_tuples = [(x[0], s, x[2]) for x, s in zip(tuples, tensor_sizes)]
+            host_descs = self.agent.get_xfer_descs(buffers)
+        elif isinstance(buffers[0], tuple):
+            storage_tuples = [(x[0], y[1], x[2]) for x, y in zip(tuples, buffers)]
+            host_descs = self.agent.get_xfer_descs(
+                [(x[0], x[1], 0) for x in buffers], "DRAM"
+            )
+        else:
+            return False
+
+        storage_descs = self.agent.get_xfer_descs(
+            storage_tuples, self.backend_selector.mem_type
+        )
+
+        if (host_descs is None) or (storage_descs is None):
+            logger.error("Failed to get transfer descriptors")
+            return False
+
+        # Initialize transfer, default assumption that tensor was registered
         try:
-            # Get transfer descriptors
-            if (tensor_descs := self.agent.get_xfer_descs(tensors)) is None or (
-                file_descs := self.agent.get_xfer_descs(
-                    transfer_tuples, self.backend_selector.mem_type
-                )
-            ) is None:
-                logger.error("Failed to get transfer descriptors")
+            xfer_req = self.agent.initialize_xfer(
+                direction, host_descs, storage_descs, self.agent_name
+            )
+        except Exception:
+            # Check if it was due to missing pre-registration
+            if not self.register_buffers(buffers):
+                logger.error("Failed to register tensors/buffers")
                 return False
 
-            # Initialize and execute transfer
-            if (
-                xfer_req := self.agent.initialize_xfer(
-                    direction, tensor_descs, file_descs, self.agent_name
+            try:
+                xfer_req = self.agent.initialize_xfer(
+                    direction, host_descs, storage_descs, self.agent_name
                 )
-            ) is None:
-                logger.error("Failed to create transfer request")
+            except Exception as e:
+                logger.error(f"Failed to create transfer request: {e}")
                 return False
 
+        # Execute transfer and wait for its completion
+        try:
             state = self.agent.transfer(xfer_req)
             while state != "DONE":
                 state = self.agent.check_xfer_state(xfer_req)
                 if state == "ERR":
+                    self.agent.release_xfer_handle(xfer_req)
                     logger.error("Transfer failed")
                     return False
-                time.sleep(0.0001) # Can be changed to os.sched_yield() or parametrized
+                time.sleep(0.0001)  # Can be changed to os.sched_yield() or parametrized
+
+            self.agent.release_xfer_handle(xfer_req)
             return True
 
         except Exception as e:
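The rewritten _execute_transfer above now attempts initialize_xfer first and registers the host buffers only if that attempt raises, so callers that pre-register through the new register_buffers method avoid the exception-driven fallback on every transfer. A sketch of that pre-registration pattern, assuming cache is a constructed HiCacheNixl instance (buffer shapes are illustrative; the two forms describe the same memory two ways):

import torch

# Form 1: register tensors directly; mem_type is left for NIXL to infer
# from the tensor device (host memory for CPU, device memory for CUDA).
host_pool = [torch.empty(4096, dtype=torch.uint8) for _ in range(8)]
cache.register_buffers(host_pool)

# Form 2: register raw (addr, len) regions instead; register_buffers wraps
# them into (addr, len, 0, "") tuples and registers them as "DRAM".
regions = [(t.data_ptr(), t.numel() * t.element_size()) for t in host_pool]
cache.register_buffers(regions)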
@@ -106,45 +158,87 @@ class HiCacheNixl(HiCacheStorage):
             logger.error(f"Traceback: {traceback.format_exc()}")
             return False
 
-    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
-        if not keys:
-            return True
-
-        if self.backend_selector.mem_type == "FILE":
-            file_paths = []
-            for key in keys:
-                tensor_path = self.file_manager.get_file_path(key)
-                if not self.file_manager.create_file(tensor_path):
-                    logger.error(f"Failed to create file {tensor_path}")
-                    return False
-                file_paths.append(tensor_path)
-            return self._execute_transfer(values, file_paths, "WRITE")
-        else:
-            return self._execute_transfer(values, keys, "WRITE")
-
-    def set(self, key: str, value: torch.Tensor) -> bool:
-        return self.batch_set([key], [value])
-
     def get(
-        self, key: str, dst_tensor: Optional[torch.Tensor] = None
+        self,
+        key: str,
+        target_location: Optional[torch.Tensor | int] = None,
+        target_sizes: Optional[int] = None,
     ) -> torch.Tensor | None:
-        if dst_tensor is None:  # To be removed, being compatible with the current API
+        # To be removed, being compatible with the current API
+        if target_location is None:
             return None
-        result = self.batch_get([key], [dst_tensor])
+        if target_sizes:
+            result = self.batch_get([key], [target_location], [target_sizes])
+        else:
+            result = self.batch_get([key], [target_location])
         return result[0] if result else None
 
     def batch_get(
-        self, keys: List[str], dst_tensors: List[torch.Tensor]
-    ) -> List[Optional[torch.Tensor]]:
+        self,
+        keys: List[str],
+        target_locations: Optional[List[torch.Tensor | int]] = None,
+        target_sizes: Optional[List[int]] = None,
+    ) -> List[torch.Tensor | None]:
         if not keys:
             return []
 
+        # To be removed, being compatible with the current API
+        if not target_locations:
+            return [None] * len(keys)
+
+        if target_sizes and (len(target_sizes) != len(target_locations)):
+            logger.error("Mismatch between number of target_locations and target_sizes")
+            return [None] * len(keys)
+        if target_sizes:
+            dest = list(zip(target_locations, target_sizes))
+        else:
+            dest = target_locations
+
         if self.backend_selector.mem_type == "FILE":
             file_paths = [self.file_manager.get_file_path(key) for key in keys]
-            success = self._execute_transfer(dst_tensors, file_paths, "READ")
+            success = self._execute_transfer(dest, file_paths, "READ")
         else:
-            success = self._execute_transfer(dst_tensors, keys, "READ")
-        return dst_tensors if success else [None] * len(keys)
+            success = self._execute_transfer(dest, keys, "READ")
+        return target_locations if success and not target_sizes else [None] * len(keys)
+
+    def set(
+        self,
+        key: str,
+        value: Optional[torch.Tensor] = None,
+        target_location: Optional[int] = None,
+        target_sizes: Optional[int] = None,
+    ) -> bool:
+        if target_location and target_sizes:
+            return self.batch_set([key], None, [target_location], [target_sizes])
+        else:
+            return self.batch_set([key], [value])
+
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[List[torch.Tensor]] = None,
+        target_locations: Optional[List[int]] = None,
+        target_sizes: Optional[List[int]] = None,
+    ) -> bool:
+        if not keys or (not values and (not target_locations or not target_sizes)):
+            logger.error("Keys or values were not passed")
+            return False
+
+        if not values:
+            values = list(zip(target_locations, target_sizes))
+
+        if self.backend_selector.mem_type == "FILE":
+            file_paths = []
+            for key in keys:
+                file_path = self.file_manager.get_file_path(key)
+                # New file per set, to be updated when partial writes is added to HiCache
+                if not self.file_manager.create_file(file_path):
+                    logger.error(f"Failed to create file {file_path}")
+                    return False
+                file_paths.append(file_path)
+            return self._execute_transfer(values, file_paths, "WRITE")
+        else:  # mem_type == "OBJ"
+            return self._execute_transfer(values, keys, "WRITE")
 
     def exists(self, key: str) -> bool:
         tuples = self.registration.create_query_tuples(
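For reference, the reworked set/get API above accepts two calling conventions, mirroring the unit-test changes further down in this diff (keys and shapes are illustrative; cache is a HiCacheNixl instance):

import torch

keys = ["k1", "k2", "k3"]
values = [torch.randn(16) for _ in keys]
dst = [torch.zeros_like(v) for v in values]

# Tensor mode: each tensor carries both its address and its size.
cache.batch_set(keys, values)
retrieved = cache.batch_get(keys, dst)  # returns dst on success

# addr,len mode: raw pointers plus explicit sizes; batch_get returns
# [None, ...] and the bytes land directly in the destination buffers.
src_addrs = [v.data_ptr() for v in values]
src_lens = [v.numel() * v.element_size() for v in values]
dst_addrs = [t.data_ptr() for t in dst]
dst_lens = [t.numel() * t.element_size() for t in dst]
cache.batch_set(keys, None, src_addrs, src_lens)
cache.batch_get(keys, dst_addrs, dst_lens)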
sglang/srt/mem_cache/storage/nixl/nixl_utils.py

@@ -109,66 +109,35 @@ class NixlRegistration:
         return [(0, 0, key)]
 
     def _register_memory(
-        self, items: Union[List[tuple], List[torch.Tensor]], mem_type: str, desc: str
+        self,
+        items: Union[List[tuple], torch.Tensor, List[torch.Tensor]],
+        mem_type: Optional[str] = None,
     ) -> Optional[Any]:
         """Common registration logic for files, objects, and buffers.
 
         Args:
             items: List of tuples or tensors to register
-            mem_type: Memory type ("FILE", "OBJ", "DRAM", "VRAM")
-            desc: Description for logging
+            mem_type: Memory type ("FILE", "OBJ") or None for tensor or list of tensors
         """
-        try:
-            if not items:
-                return None
-
-            reg_descs = self.agent.get_reg_descs(items, mem_type)
-            if reg_descs is None:
-                logger.error("Failed to create registration descriptors")
-                return None
-
-            registered_memory = self.agent.register_memory(reg_descs)
-            if registered_memory:
-                return registered_memory
-            else:
-                logger.error("Failed to register with NIXL")
-                return None
-
-        except Exception as e:
-            logger.error(f"Failed to register {desc}: {e}")
+        if isinstance(items, list) and not items:
             return None
 
-    def register_buffers(
-        self, buffers: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> Optional[Any]:
-        """Register tensors/buffers with NIXL."""
-        if isinstance(buffers, torch.Tensor):
-            buffers = [buffers]
-
-        if not buffers:
+        reg_descs = self.agent.get_reg_descs(items, mem_type)
+        if reg_descs is None:
+            logger.error("Failed to create registration descriptors")
             return None
 
-        # Determine memory type based on tensor device
-        mem_type = "VRAM" if buffers[0].device.type == "cuda" else "DRAM"
-        return self._register_memory(buffers, mem_type, "buffers")
-
-    def register_files(self, tuples: List[tuple]) -> Optional[Any]:
-        """Register files with NIXL using (0, 0, fd, file_path) tuples."""
-        return self._register_memory(tuples, "FILE", "files")
-
-    def register_objects(
-        self, keys: List[str], tensors: Optional[List[torch.Tensor]] = None
-    ) -> Optional[Any]:
-        """Register objects with NIXL."""
-        if not keys:
+        try:
+            registered_memory = self.agent.register_memory(reg_descs)
+            return registered_memory  # Could be None in case of error
+        except Exception as e:
+            if not mem_type:
+                logger.error(f"Failed to register Tensors with NIXL: {e}")
+            else:
+                logger.error(
+                    f"Failed to register memory of type {mem_type} with NIXL: {e}"
+                )
             return None
 
-        # Create object tuples with proper sizes
-        tuples = [
-            (0, tensor.element_size() * tensor.numel() if tensor else 0, key)
-            for key, tensor in zip(keys, tensors or [None] * len(keys))
-        ]
-        return self._register_memory(tuples, "OBJ", "objects")
-
 
 class NixlFileManager:
     """Handles file system operations for NIXL."""
@@ -221,12 +190,9 @@ class NixlFileManager:
             return False
 
     def files_to_nixl_tuples(
-        self, file_paths: List[str], open_file: bool = True
+        self, file_paths: List[str]
     ) -> List[Tuple[int, int, int, str]]:
         """Create NIXL tuples (offset, length, fd, file_path) for given files."""
-        if not open_file:
-            return [(0, 0, 0, path) for path in file_paths]
-
         tuples = []
         for path in file_paths:
             if (fd := self.open_file(path)) is None:
sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py

@@ -7,8 +7,11 @@ from unittest.mock import MagicMock
 
 import torch
 
-from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
-from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration
+from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
+from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
+    NixlFileManager,
+    NixlRegistration,
+)
 
 
 class TestNixlUnified(unittest.TestCase):
@@ -88,8 +91,27 @@ class TestNixlUnified(unittest.TestCase):
 
         # Test get
         retrieved = self.hicache.get(key, dst_tensor)
+        self.verify_tensors_equal(value, dst_tensor)
         self.verify_tensors_equal(value, retrieved)
 
+        # Same test in addr,len mode with another key and dst_tensor
+        key2 = "test_key2"
+        dst_tensor2 = torch.zeros_like(value, device="cpu")
+        src_addr, src_len = value.data_ptr(), value.numel() * value.element_size()
+        dst_addr, dst_len = (
+            dst_tensor2.data_ptr(),
+            dst_tensor2.numel() * dst_tensor2.element_size(),
+        )
+
+        # Test set
+        self.assertTrue(self.hicache.set(key, None, src_addr, src_len))
+        self.assertTrue(self.hicache.exists(key))
+
+        # Test get
+        retrieved2 = self.hicache.get(key, dst_addr, dst_len)
+        self.assertTrue(retrieved2 == None)
+        self.verify_tensors_equal(value, dst_tensor2)
+
     def test_batch_set_get(self):
         """Test batch tensor set/get operations."""
         keys = ["key1", "key2", "key3"]
@@ -108,6 +130,23 @@ class TestNixlUnified(unittest.TestCase):
         retrieved = self.hicache.batch_get(keys, dst_tensors)
         self.verify_tensor_lists_equal(values, retrieved)
 
+        # Same test in addr,len mode with another key and dst_tensor
+        keys2 = ["key4", "key5", "key6"]
+        dst_tensors2 = [torch.zeros_like(v, device="cpu") for v in values]
+        src_addrs = [v.data_ptr() for v in values]
+        src_lens = [v.numel() * v.element_size() for v in values]
+        dst_addrs = [dt.data_ptr() for dt in dst_tensors2]
+        dst_lens = [dt.numel() * dt.element_size() for dt in dst_tensors2]
+
+        # Test batch set
+        self.assertTrue(self.hicache.batch_set(keys2, None, src_addrs, src_lens))
+        self.assertTrue(all(self.hicache.exists(key) for key in keys2))
+
+        # Test batch get
+        retrieved2 = self.hicache.batch_get(keys, dst_addrs, dst_lens)
+        self.assertTrue(all(ret == None for ret in retrieved2))
+        self.verify_tensor_lists_equal(values, dst_tensors2)
+
     def test_mixed_operations(self):
         """Test mixing single and batch operations."""
         # Test interleaved set/get operations
@@ -170,7 +209,7 @@ class TestNixlUnified(unittest.TestCase):
         self.file_manager.create_file(test_file)
 
         # Test tuple creation
-        tuples = self.file_manager.files_to_nixl_tuples([test_file], False)
+        tuples = self.file_manager.files_to_nixl_tuples([test_file])
         self.assertIsNotNone(tuples)
         self.assertTrue(len(tuples) > 0)
 
@@ -190,11 +229,11 @@ class TestNixlUnified(unittest.TestCase):
         tensor = torch.randn(10, 10)
 
         # Test buffer registration
-        self.assertIsNotNone(self.registration.register_buffers(tensor))
+        self.assertIsNotNone(self.hicache.register_buffers(tensor))
 
         # Test batch registration
         tensors = [torch.randn(5, 5) for _ in range(3)]
-        self.assertIsNotNone(self.registration.register_buffers(tensors))
+        self.assertIsNotNone(self.hicache.register_buffers(tensors))
 
     def test_register_files_with_tuples(self):
         """Test registration of files using NIXL tuples."""
@@ -203,8 +242,8 @@ class TestNixlUnified(unittest.TestCase):
             self.file_manager.create_file(file)
 
         # Create tuples and register
-        tuples = self.file_manager.files_to_nixl_tuples(files, False)
-        self.registration.register_files(tuples)
+        tuples = self.file_manager.files_to_nixl_tuples(files)
+        self.hicache.register_files(tuples)
 
         # Verify tuples
         self.assertEqual(len(tuples), len(files))
sglang/srt/model_executor/cuda_graph_runner.py

@@ -34,9 +34,10 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     get_attention_tp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -239,6 +240,8 @@ class CudaGraphRunner:
     def __init__(self, model_runner: ModelRunner):
         # Parse args
         self.model_runner = model_runner
+        self.device = model_runner.device
+        self.device_module = torch.get_device_module(self.device)
         self.graphs = {}
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
@@ -304,13 +307,15 @@
             self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
 
         # Graph inputs
-        with torch.device("cuda"):
+        with torch.device(self.device):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            self.out_cache_loc = torch.zeros(
+                (self.max_num_token,), dtype=self._cache_loc_dtype()
+            )
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
             self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
349
354
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
350
355
  (self.dp_size,), dtype=torch.int32
351
356
  )
352
- self.gathered_buffer = torch.zeros(
353
- (
354
- self.max_num_token * self.dp_size,
355
- self.model_runner.model_config.hidden_size,
356
- ),
357
- dtype=self.model_runner.dtype,
358
- )
359
357
  else:
360
358
  assert self.require_attn_tp_gather
361
359
  self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
362
360
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
363
361
  (1,), dtype=torch.int32
364
362
  )
365
- self.gathered_buffer = torch.zeros(
366
- (
367
- self.max_num_token,
368
- self.model_runner.model_config.hidden_size,
369
- ),
370
- dtype=self.model_runner.dtype,
371
- )
372
363
  else:
373
364
  self.global_num_tokens_gpu = None
374
365
  self.global_num_tokens_for_logprob_gpu = None
375
- self.gathered_buffer = None
376
366
 
377
367
  self.custom_mask = torch.ones(
378
368
  (
@@ -380,12 +370,12 @@
                 * self.num_tokens_per_bs
             ),
             dtype=torch.bool,
-            device="cuda",
+            device=self.device,
         )
         self.next_token_logits_buffer = torch.zeros(
             (self.max_num_token, self.model_runner.model_config.vocab_size),
             dtype=torch.float,
-            device="cuda",
+            device=self.device,
         )
 
         # Capture
@@ -397,6 +387,9 @@
                 f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )
 
+    def _cache_loc_dtype(self):
+        return torch.int64
+
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
@@ -516,8 +509,16 @@
             )
             logger.info(log_message)
 
+    def _capture_graph(self, graph, pool, stream, run_once_fn):
+        with self.device_module.graph(graph, pool=pool, stream=stream):
+            out = run_once_fn()
+        return out
+
+    def _create_device_graph(self):
+        return torch.cuda.CUDAGraph()
+
     def capture_one_batch_size(self, bs: int, forward: Callable):
-        graph = torch.cuda.CUDAGraph()
+        graph = self._create_device_graph()
         stream = self.stream
         num_tokens = bs * self.num_tokens_per_bs
 
@@ -556,7 +557,7 @@
                     device=input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
        elif self.require_attn_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
@@ -572,9 +573,9 @@
                    device=input_ids.device,
                )
            )
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
        else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
        spec_info = self.get_spec_info(num_tokens)
        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
@@ -607,8 +608,8 @@
            positions=positions,
            global_num_tokens_gpu=self.global_num_tokens_gpu,
            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
            mrope_positions=mrope_positions,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
@@ -637,6 +638,7 @@
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
            kwargs = {}
            if (
@@ -656,19 +658,17 @@
            return logits_output_or_pp_proxy_tensors
 
        for _ in range(2):
-            torch.cuda.synchronize()
+            self.device_module.synchronize()
            self.model_runner.tp_group.barrier()
-
            run_once()
 
        if get_global_graph_memory_pool() is None:
-            set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
+            set_global_graph_memory_pool(self.device_module.graph_pool_handle())
        # Set graph pool id globally to be able to use symmetric memory
        set_graph_pool_id(get_global_graph_memory_pool())
-        with torch.cuda.graph(
-            graph, pool=get_global_graph_memory_pool(), stream=stream
-        ):
-            out = run_once()
+        out = self._capture_graph(
+            graph, get_global_graph_memory_pool(), stream, run_once
+        )
 
        return graph, out
 
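Taken together, the device_module indirection, _create_device_graph, _capture_graph, and _cache_loc_dtype hooks make the capture loop device-agnostic; the new sglang/srt/model_executor/npu_graph_runner.py in this release (see the file list) builds on them for Ascend. An illustrative subclass, not the actual implementation:

import torch

class NpuGraphRunner(CudaGraphRunner):
    def _create_device_graph(self):
        # Assumes torch_npu is installed and exposes torch.npu.NPUGraph.
        return torch.npu.NPUGraph()

    def _cache_loc_dtype(self):
        # Assumption: the Ascend allocator tracks cache locations as int32
        # (this release also adds sglang/srt/mem_cache/allocator_ascend.py).
        return torch.int32

Because __init__ resolves self.device_module via torch.get_device_module(self.device), the synchronize(), graph_pool_handle(), and graph() calls in the capture path route to the torch.npu namespace on an "npu" device without further overrides.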