sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@ import logging
3
3
  import os
4
4
  import time
5
5
  import uuid
6
- from typing import Dict, List, Optional, Tuple, Union
6
+ from typing import Any, Dict, List, Optional, Tuple, Union
7
7
 
8
8
  import torch
9
9
 
@@ -28,6 +28,8 @@ class HiCacheNixl(HiCacheStorage):
28
28
 
29
29
  def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
30
30
  """Initialize NIXL storage connector."""
31
+ # Might be better to be unified across HiCache backends and moved to HiCacheController
32
+ file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
31
33
  self.file_manager = (
32
34
  NixlFileManager(file_path)
33
35
  if plugin not in NixlBackendSelection.OBJ_PLUGINS
@@ -44,59 +46,109 @@ class HiCacheNixl(HiCacheStorage):
44
46
 
45
47
  self.registration = NixlRegistration(self.agent)
46
48
 
49
+ def register_buffers(
50
+ self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
51
+ ) -> Optional[Any]:
52
+ """Register tensor(s) or target locations in host memory (list of addr,len tuples) with NIXL."""
53
+ if isinstance(buffers[0], tuple):
54
+ tuples = [(x[0], x[1], 0, "") for x in buffers]
55
+ return self.registration._register_memory(tuples, "DRAM")
56
+ else:
57
+ return self.registration._register_memory(buffers)
58
+
59
+ def register_files(
60
+ self, file_paths: List[str], open_file: Optional[bool] = True
61
+ ) -> Optional[Any]:
62
+ """Register files with NIXL."""
63
+ tuples = self.file_manager.files_to_nixl_tuples(file_paths)
64
+ return self.registration._register_memory(tuples, "FILE")
65
+
66
+ def register_objects(
67
+ self, keys: List[str], sizes: Optional[List[int]] = None
68
+ ) -> Optional[Any]:
69
+ """Register objects with NIXL."""
70
+ if not keys:
71
+ return None
72
+ tuples = [(0, 0, key, "") for key in keys]
73
+ return self.registration._register_memory(tuples, "OBJ")
74
+
47
75
  def _execute_transfer(
48
- self, tensors: List[torch.Tensor], keys: List[str], direction: str
76
+ self,
77
+ buffers: Optional[List[torch.Tensor | tuple]],
78
+ keys: List[str],
79
+ direction: str,
49
80
  ) -> bool:
50
- if len(tensors) != len(keys):
51
- logger.error("Mismatch between number of tensors and files/objects")
81
+ if len(buffers) != len(keys):
82
+ logger.error("Mismatch between number of tensors/buffers and files/objects")
52
83
  return False
53
84
 
54
- if not self.registration.register_buffers(tensors):
55
- logger.error("Failed to register tensors")
56
- return False
57
-
58
- # Get transfer tuples based on backend type
59
- tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
85
+ # Registering file and object keys per transfer, to be updated when
86
+ # pre-registration for file and object is added to HiCache.
60
87
  if self.backend_selector.mem_type == "FILE":
61
- file_tuples = self.file_manager.files_to_nixl_tuples(keys)
62
- if not file_tuples or not self.registration.register_files(file_tuples):
88
+ tuples = self.file_manager.files_to_nixl_tuples(keys)
89
+ if not tuples or not self.registration._register_memory(tuples, "FILE"):
63
90
  logger.error("Failed to prepare files for transfer")
64
91
  return False
65
- transfer_tuples = [
66
- (x[0], s, x[2]) for x, s in zip(file_tuples, tensor_sizes)
67
- ]
68
- else:
69
- if not self.registration.register_objects(keys, tensors):
92
+ else: # mem_type == "OBJ"
93
+ tuples = [(0, 0, key, "") for key in keys]
94
+ if not tuples or not self.registration._register_memory(tuples, "OBJ"):
70
95
  logger.error("Failed to register objects")
71
96
  return False
72
- transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]
73
97
 
98
+ # Prepare transfer descriptors
99
+ if isinstance(buffers[0], torch.Tensor):
100
+ tensor_sizes = [
101
+ tensor.element_size() * tensor.numel() for tensor in buffers
102
+ ]
103
+ storage_tuples = [(x[0], s, x[2]) for x, s in zip(tuples, tensor_sizes)]
104
+ host_descs = self.agent.get_xfer_descs(buffers)
105
+ elif isinstance(buffers[0], tuple):
106
+ storage_tuples = [(x[0], y[1], x[2]) for x, y in zip(tuples, buffers)]
107
+ host_descs = self.agent.get_xfer_descs(
108
+ [(x[0], x[1], 0) for x in buffers], "DRAM"
109
+ )
110
+ else:
111
+ return False
112
+
113
+ storage_descs = self.agent.get_xfer_descs(
114
+ storage_tuples, self.backend_selector.mem_type
115
+ )
116
+
117
+ if (host_descs is None) or (storage_descs is None):
118
+ logger.error("Failed to get transfer descriptors")
119
+ return False
120
+
121
+ # Initialize transfer, default assumption that tensor was registered
74
122
  try:
75
- # Get transfer descriptors
76
- if (tensor_descs := self.agent.get_xfer_descs(tensors)) is None or (
77
- file_descs := self.agent.get_xfer_descs(
78
- transfer_tuples, self.backend_selector.mem_type
79
- )
80
- ) is None:
81
- logger.error("Failed to get transfer descriptors")
123
+ xfer_req = self.agent.initialize_xfer(
124
+ direction, host_descs, storage_descs, self.agent_name
125
+ )
126
+ except Exception:
127
+ # Check if it was due to missing pre-registration
128
+ if not self.register_buffers(buffers):
129
+ logger.error("Failed to register tensors/buffers")
82
130
  return False
83
131
 
84
- # Initialize and execute transfer
85
- if (
86
- xfer_req := self.agent.initialize_xfer(
87
- direction, tensor_descs, file_descs, self.agent_name
132
+ try:
133
+ xfer_req = self.agent.initialize_xfer(
134
+ direction, host_descs, storage_descs, self.agent_name
88
135
  )
89
- ) is None:
90
- logger.error("Failed to create transfer request")
136
+ except Exception as e:
137
+ logger.error(f"Failed to create transfer request: {e}")
91
138
  return False
92
139
 
140
+ # Execute transfer and wait for its completion
141
+ try:
93
142
  state = self.agent.transfer(xfer_req)
94
143
  while state != "DONE":
95
144
  state = self.agent.check_xfer_state(xfer_req)
96
145
  if state == "ERR":
146
+ self.agent.release_xfer_handle(xfer_req)
97
147
  logger.error("Transfer failed")
98
148
  return False
99
- time.sleep(0.0001) # Can be changed to os.sched_yield() or parametrized
149
+ time.sleep(0.0001) # Can be changed to os.sched_yield() or parametrized
150
+
151
+ self.agent.release_xfer_handle(xfer_req)
100
152
  return True
101
153
 
102
154
  except Exception as e:
@@ -106,45 +158,87 @@ class HiCacheNixl(HiCacheStorage):
106
158
  logger.error(f"Traceback: {traceback.format_exc()}")
107
159
  return False
108
160
 
109
- def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
110
- if not keys:
111
- return True
112
-
113
- if self.backend_selector.mem_type == "FILE":
114
- file_paths = []
115
- for key in keys:
116
- tensor_path = self.file_manager.get_file_path(key)
117
- if not self.file_manager.create_file(tensor_path):
118
- logger.error(f"Failed to create file {tensor_path}")
119
- return False
120
- file_paths.append(tensor_path)
121
- return self._execute_transfer(values, file_paths, "WRITE")
122
- else:
123
- return self._execute_transfer(values, keys, "WRITE")
124
-
125
- def set(self, key: str, value: torch.Tensor) -> bool:
126
- return self.batch_set([key], [value])
127
-
128
161
  def get(
129
- self, key: str, dst_tensor: Optional[torch.Tensor] = None
162
+ self,
163
+ key: str,
164
+ target_location: Optional[torch.Tensor | int] = None,
165
+ target_sizes: Optional[int] = None,
130
166
  ) -> torch.Tensor | None:
131
- if dst_tensor is None: # To be removed, being compatible with the current API
167
+ # To be removed, being compatible with the current API
168
+ if target_location is None:
132
169
  return None
133
- result = self.batch_get([key], [dst_tensor])
170
+ if target_sizes:
171
+ result = self.batch_get([key], [target_location], [target_sizes])
172
+ else:
173
+ result = self.batch_get([key], [target_location])
134
174
  return result[0] if result else None
135
175
 
136
176
  def batch_get(
137
- self, keys: List[str], dst_tensors: List[torch.Tensor]
138
- ) -> List[Optional[torch.Tensor]]:
177
+ self,
178
+ keys: List[str],
179
+ target_locations: Optional[List[torch.Tensor | int]] = None,
180
+ target_sizes: Optional[List[int]] = None,
181
+ ) -> List[torch.Tensor | None]:
139
182
  if not keys:
140
183
  return []
141
184
 
185
+ # To be removed, being compatible with the current API
186
+ if not target_locations:
187
+ return [None] * len(keys)
188
+
189
+ if target_sizes and (len(target_sizes) != len(target_locations)):
190
+ logger.error("Mismatch between number of target_locations and target_sizes")
191
+ return [None] * len(keys)
192
+ if target_sizes:
193
+ dest = list(zip(target_locations, target_sizes))
194
+ else:
195
+ dest = target_locations
196
+
142
197
  if self.backend_selector.mem_type == "FILE":
143
198
  file_paths = [self.file_manager.get_file_path(key) for key in keys]
144
- success = self._execute_transfer(dst_tensors, file_paths, "READ")
199
+ success = self._execute_transfer(dest, file_paths, "READ")
145
200
  else:
146
- success = self._execute_transfer(dst_tensors, keys, "READ")
147
- return dst_tensors if success else [None] * len(keys)
201
+ success = self._execute_transfer(dest, keys, "READ")
202
+ return target_locations if success and not target_sizes else [None] * len(keys)
203
+
204
+ def set(
205
+ self,
206
+ key: str,
207
+ value: Optional[torch.Tensor] = None,
208
+ target_location: Optional[int] = None,
209
+ target_sizes: Optional[int] = None,
210
+ ) -> bool:
211
+ if target_location and target_sizes:
212
+ return self.batch_set([key], None, [target_location], [target_sizes])
213
+ else:
214
+ return self.batch_set([key], [value])
215
+
216
+ def batch_set(
217
+ self,
218
+ keys: List[str],
219
+ values: Optional[List[torch.Tensor]] = None,
220
+ target_locations: Optional[List[int]] = None,
221
+ target_sizes: Optional[List[int]] = None,
222
+ ) -> bool:
223
+ if not keys or (not values and (not target_locations or not target_sizes)):
224
+ logger.error("Keys or values were not passed")
225
+ return False
226
+
227
+ if not values:
228
+ values = list(zip(target_locations, target_sizes))
229
+
230
+ if self.backend_selector.mem_type == "FILE":
231
+ file_paths = []
232
+ for key in keys:
233
+ file_path = self.file_manager.get_file_path(key)
234
+ # New file per set, to be updated when partial writes is added to HiCache
235
+ if not self.file_manager.create_file(file_path):
236
+ logger.error(f"Failed to create file {file_path}")
237
+ return False
238
+ file_paths.append(file_path)
239
+ return self._execute_transfer(values, file_paths, "WRITE")
240
+ else: # mem_type == "OBJ"
241
+ return self._execute_transfer(values, keys, "WRITE")
148
242
 
149
243
  def exists(self, key: str) -> bool:
150
244
  tuples = self.registration.create_query_tuples(
@@ -109,66 +109,35 @@ class NixlRegistration:
109
109
  return [(0, 0, key)]
110
110
 
111
111
  def _register_memory(
112
- self, items: Union[List[tuple], List[torch.Tensor]], mem_type: str, desc: str
112
+ self,
113
+ items: Union[List[tuple], torch.Tensor, List[torch.Tensor]],
114
+ mem_type: Optional[str] = None,
113
115
  ) -> Optional[Any]:
114
116
  """Common registration logic for files, objects, and buffers.
115
117
  Args:
116
118
  items: List of tuples or tensors to register
117
- mem_type: Memory type ("FILE", "OBJ", "DRAM", "VRAM")
118
- desc: Description for logging
119
+ mem_type: Memory type ("FILE", "OBJ") or None for tensor or list of tensors
119
120
  """
120
- try:
121
- if not items:
122
- return None
123
-
124
- reg_descs = self.agent.get_reg_descs(items, mem_type)
125
- if reg_descs is None:
126
- logger.error("Failed to create registration descriptors")
127
- return None
128
-
129
- registered_memory = self.agent.register_memory(reg_descs)
130
- if registered_memory:
131
- return registered_memory
132
- else:
133
- logger.error("Failed to register with NIXL")
134
- return None
135
-
136
- except Exception as e:
137
- logger.error(f"Failed to register {desc}: {e}")
121
+ if isinstance(items, list) and not items:
138
122
  return None
139
123
 
140
- def register_buffers(
141
- self, buffers: Union[torch.Tensor, List[torch.Tensor]]
142
- ) -> Optional[Any]:
143
- """Register tensors/buffers with NIXL."""
144
- if isinstance(buffers, torch.Tensor):
145
- buffers = [buffers]
146
-
147
- if not buffers:
124
+ reg_descs = self.agent.get_reg_descs(items, mem_type)
125
+ if reg_descs is None:
126
+ logger.error("Failed to create registration descriptors")
148
127
  return None
149
128
 
150
- # Determine memory type based on tensor device
151
- mem_type = "VRAM" if buffers[0].device.type == "cuda" else "DRAM"
152
- return self._register_memory(buffers, mem_type, "buffers")
153
-
154
- def register_files(self, tuples: List[tuple]) -> Optional[Any]:
155
- """Register files with NIXL using (0, 0, fd, file_path) tuples."""
156
- return self._register_memory(tuples, "FILE", "files")
157
-
158
- def register_objects(
159
- self, keys: List[str], tensors: Optional[List[torch.Tensor]] = None
160
- ) -> Optional[Any]:
161
- """Register objects with NIXL."""
162
- if not keys:
129
+ try:
130
+ registered_memory = self.agent.register_memory(reg_descs)
131
+ return registered_memory # Could be None in case of error
132
+ except Exception as e:
133
+ if not mem_type:
134
+ logger.error(f"Failed to register Tensors with NIXL: {e}")
135
+ else:
136
+ logger.error(
137
+ f"Failed to register memory of type {mem_type} with NIXL: {e}"
138
+ )
163
139
  return None
164
140
 
165
- # Create object tuples with proper sizes
166
- tuples = [
167
- (0, tensor.element_size() * tensor.numel() if tensor else 0, key)
168
- for key, tensor in zip(keys, tensors or [None] * len(keys))
169
- ]
170
- return self._register_memory(tuples, "OBJ", "objects")
171
-
172
141
 
173
142
  class NixlFileManager:
174
143
  """Handles file system operations for NIXL."""
@@ -221,12 +190,9 @@ class NixlFileManager:
221
190
  return False
222
191
 
223
192
  def files_to_nixl_tuples(
224
- self, file_paths: List[str], open_file: bool = True
193
+ self, file_paths: List[str]
225
194
  ) -> List[Tuple[int, int, int, str]]:
226
195
  """Create NIXL tuples (offset, length, fd, file_path) for given files."""
227
- if not open_file:
228
- return [(0, 0, 0, path) for path in file_paths]
229
-
230
196
  tuples = []
231
197
  for path in file_paths:
232
198
  if (fd := self.open_file(path)) is None:
@@ -7,8 +7,11 @@ from unittest.mock import MagicMock
7
7
 
8
8
  import torch
9
9
 
10
- from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
11
- from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration
10
+ from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
11
+ from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
12
+ NixlFileManager,
13
+ NixlRegistration,
14
+ )
12
15
 
13
16
 
14
17
  class TestNixlUnified(unittest.TestCase):
@@ -88,8 +91,27 @@ class TestNixlUnified(unittest.TestCase):
88
91
 
89
92
  # Test get
90
93
  retrieved = self.hicache.get(key, dst_tensor)
94
+ self.verify_tensors_equal(value, dst_tensor)
91
95
  self.verify_tensors_equal(value, retrieved)
92
96
 
97
+ # Same test in addr,len mode with another key and dst_tensor
98
+ key2 = "test_key2"
99
+ dst_tensor2 = torch.zeros_like(value, device="cpu")
100
+ src_addr, src_len = value.data_ptr(), value.numel() * value.element_size()
101
+ dst_addr, dst_len = (
102
+ dst_tensor2.data_ptr(),
103
+ dst_tensor2.numel() * dst_tensor2.element_size(),
104
+ )
105
+
106
+ # Test set
107
+ self.assertTrue(self.hicache.set(key, None, src_addr, src_len))
108
+ self.assertTrue(self.hicache.exists(key))
109
+
110
+ # Test get
111
+ retrieved2 = self.hicache.get(key, dst_addr, dst_len)
112
+ self.assertTrue(retrieved2 == None)
113
+ self.verify_tensors_equal(value, dst_tensor2)
114
+
93
115
  def test_batch_set_get(self):
94
116
  """Test batch tensor set/get operations."""
95
117
  keys = ["key1", "key2", "key3"]
@@ -108,6 +130,23 @@ class TestNixlUnified(unittest.TestCase):
108
130
  retrieved = self.hicache.batch_get(keys, dst_tensors)
109
131
  self.verify_tensor_lists_equal(values, retrieved)
110
132
 
133
+ # Same test in addr,len mode with another key and dst_tensor
134
+ keys2 = ["key4", "key5", "key6"]
135
+ dst_tensors2 = [torch.zeros_like(v, device="cpu") for v in values]
136
+ src_addrs = [v.data_ptr() for v in values]
137
+ src_lens = [v.numel() * v.element_size() for v in values]
138
+ dst_addrs = [dt.data_ptr() for dt in dst_tensors2]
139
+ dst_lens = [dt.numel() * dt.element_size() for dt in dst_tensors2]
140
+
141
+ # Test batch set
142
+ self.assertTrue(self.hicache.batch_set(keys2, None, src_addrs, src_lens))
143
+ self.assertTrue(all(self.hicache.exists(key) for key in keys2))
144
+
145
+ # Test batch get
146
+ retrieved2 = self.hicache.batch_get(keys, dst_addrs, dst_lens)
147
+ self.assertTrue(all(ret == None for ret in retrieved2))
148
+ self.verify_tensor_lists_equal(values, dst_tensors2)
149
+
111
150
  def test_mixed_operations(self):
112
151
  """Test mixing single and batch operations."""
113
152
  # Test interleaved set/get operations
@@ -170,7 +209,7 @@ class TestNixlUnified(unittest.TestCase):
170
209
  self.file_manager.create_file(test_file)
171
210
 
172
211
  # Test tuple creation
173
- tuples = self.file_manager.files_to_nixl_tuples([test_file], False)
212
+ tuples = self.file_manager.files_to_nixl_tuples([test_file])
174
213
  self.assertIsNotNone(tuples)
175
214
  self.assertTrue(len(tuples) > 0)
176
215
 
@@ -190,11 +229,11 @@ class TestNixlUnified(unittest.TestCase):
190
229
  tensor = torch.randn(10, 10)
191
230
 
192
231
  # Test buffer registration
193
- self.assertIsNotNone(self.registration.register_buffers(tensor))
232
+ self.assertIsNotNone(self.hicache.register_buffers(tensor))
194
233
 
195
234
  # Test batch registration
196
235
  tensors = [torch.randn(5, 5) for _ in range(3)]
197
- self.assertIsNotNone(self.registration.register_buffers(tensors))
236
+ self.assertIsNotNone(self.hicache.register_buffers(tensors))
198
237
 
199
238
  def test_register_files_with_tuples(self):
200
239
  """Test registration of files using NIXL tuples."""
@@ -203,8 +242,8 @@ class TestNixlUnified(unittest.TestCase):
203
242
  self.file_manager.create_file(file)
204
243
 
205
244
  # Create tuples and register
206
- tuples = self.file_manager.files_to_nixl_tuples(files, False)
207
- self.registration.register_files(tuples)
245
+ tuples = self.file_manager.files_to_nixl_tuples(files)
246
+ self.hicache.register_files(tuples)
208
247
 
209
248
  # Verify tuples
210
249
  self.assertEqual(len(tuples), len(files))
@@ -240,6 +240,8 @@ class CudaGraphRunner:
240
240
  def __init__(self, model_runner: ModelRunner):
241
241
  # Parse args
242
242
  self.model_runner = model_runner
243
+ self.device = model_runner.device
244
+ self.device_module = torch.get_device_module(self.device)
243
245
  self.graphs = {}
244
246
  self.output_buffers = {}
245
247
  self.enable_torch_compile = model_runner.server_args.enable_torch_compile
@@ -305,13 +307,15 @@ class CudaGraphRunner:
305
307
  self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
306
308
 
307
309
  # Graph inputs
308
- with torch.device("cuda"):
310
+ with torch.device(self.device):
309
311
  self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
310
312
  self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
311
313
  self.seq_lens = torch.full(
312
314
  (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
313
315
  )
314
- self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
316
+ self.out_cache_loc = torch.zeros(
317
+ (self.max_num_token,), dtype=self._cache_loc_dtype()
318
+ )
315
319
  self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
316
320
  self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
317
321
  self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
@@ -366,12 +370,12 @@ class CudaGraphRunner:
366
370
  * self.num_tokens_per_bs
367
371
  ),
368
372
  dtype=torch.bool,
369
- device="cuda",
373
+ device=self.device,
370
374
  )
371
375
  self.next_token_logits_buffer = torch.zeros(
372
376
  (self.max_num_token, self.model_runner.model_config.vocab_size),
373
377
  dtype=torch.float,
374
- device="cuda",
378
+ device=self.device,
375
379
  )
376
380
 
377
381
  # Capture
@@ -383,6 +387,9 @@ class CudaGraphRunner:
383
387
  f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
384
388
  )
385
389
 
390
+ def _cache_loc_dtype(self):
391
+ return torch.int64
392
+
386
393
  def can_run(self, forward_batch: ForwardBatch):
387
394
  if self.require_mlp_tp_gather:
388
395
  cuda_graph_bs = (
@@ -502,8 +509,16 @@ class CudaGraphRunner:
502
509
  )
503
510
  logger.info(log_message)
504
511
 
512
+ def _capture_graph(self, graph, pool, stream, run_once_fn):
513
+ with self.device_module.graph(graph, pool=pool, stream=stream):
514
+ out = run_once_fn()
515
+ return out
516
+
517
+ def _create_device_graph(self):
518
+ return torch.cuda.CUDAGraph()
519
+
505
520
  def capture_one_batch_size(self, bs: int, forward: Callable):
506
- graph = torch.cuda.CUDAGraph()
521
+ graph = self._create_device_graph()
507
522
  stream = self.stream
508
523
  num_tokens = bs * self.num_tokens_per_bs
509
524
 
@@ -643,19 +658,17 @@ class CudaGraphRunner:
643
658
  return logits_output_or_pp_proxy_tensors
644
659
 
645
660
  for _ in range(2):
646
- torch.cuda.synchronize()
661
+ self.device_module.synchronize()
647
662
  self.model_runner.tp_group.barrier()
648
-
649
663
  run_once()
650
664
 
651
665
  if get_global_graph_memory_pool() is None:
652
- set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
666
+ set_global_graph_memory_pool(self.device_module.graph_pool_handle())
653
667
  # Set graph pool id globally to be able to use symmetric memory
654
668
  set_graph_pool_id(get_global_graph_memory_pool())
655
- with torch.cuda.graph(
656
- graph, pool=get_global_graph_memory_pool(), stream=stream
657
- ):
658
- out = run_once()
669
+ out = self._capture_graph(
670
+ graph, get_global_graph_memory_pool(), stream, run_once
671
+ )
659
672
 
660
673
  return graph, out
661
674
 
@@ -241,6 +241,9 @@ class ForwardBatch:
241
241
  prefix_chunk_num_tokens: Optional[List[int]] = None
242
242
  # KV Indices for each chunk
243
243
  prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
244
+ # For MLA chunked prefix cache used in chunked prefill
245
+ # Tell attention backend whether lse needs to be returned
246
+ mha_return_lse: Optional[bool] = None
244
247
 
245
248
  # For multimodal
246
249
  mm_inputs: Optional[List[MultimodalInputs]] = None
@@ -649,7 +652,7 @@ class ForwardBatch:
649
652
  num_tokens = global_num_tokens[0]
650
653
 
651
654
  self.global_dp_buffer_len = buffer_len
652
- set_dp_buffer_len(buffer_len, num_tokens)
655
+ set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
653
656
 
654
657
  bs = self.batch_size
655
658