sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -13,9 +13,9 @@ from sglang.srt.lora.utils import (
13
13
  ROW_PARALLELISM_LINEAR_LORA_NAMES,
14
14
  LoRAType,
15
15
  get_hidden_dim,
16
- get_normalized_lora_weight_names,
16
+ get_normalized_target_modules,
17
17
  get_stacked_multiply,
18
- get_weight_name,
18
+ get_target_module_name,
19
19
  )
20
20
 
21
21
  logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class LoRAMemoryPool:
52
52
  tp_size: int,
53
53
  tp_rank: int,
54
54
  max_lora_rank: int,
55
- lora_weight_names: Set[str],
55
+ target_modules: Set[str],
56
56
  base_model: torch.nn.Module,
57
57
  ):
58
58
  self.base_hf_config: AutoConfig = base_hf_config
@@ -62,7 +62,7 @@ class LoRAMemoryPool:
62
62
  self.tp_size: int = tp_size
63
63
  self.tp_rank: int = tp_rank
64
64
  self.max_lora_rank: int = max_lora_rank
65
- self.lora_weight_names: Set[str] = lora_weight_names
65
+ self.target_modules: Set[str] = target_modules
66
66
 
67
67
  # Both A_buffer and B_buffer maps lora weight names to its buffer space.
68
68
  # A_buffer contains num_layer number of row-major tensors with shape
@@ -95,8 +95,8 @@ class LoRAMemoryPool:
95
95
  """
96
96
  if config.r > self.max_lora_rank:
97
97
  return False
98
- weights = get_normalized_lora_weight_names(config.target_modules)
99
- return weights.issubset(self.lora_weight_names)
98
+ target_module_names = get_normalized_target_modules(config.target_modules)
99
+ return target_module_names.issubset(self.target_modules)
100
100
 
101
101
  if isinstance(config, LoRAConfig):
102
102
  return _can_support(config)
@@ -139,10 +139,10 @@ class LoRAMemoryPool:
139
139
 
140
140
  def init_buffer(
141
141
  buffer: Dict[str, List[torch.Tensor]],
142
- lora_weight_names: Set[str],
142
+ target_modules: Set[str],
143
143
  get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
144
144
  ):
145
- for module_name in lora_weight_names:
145
+ for module_name in target_modules:
146
146
  lora_shape = get_lora_shape_fn(
147
147
  module_name, base_model, self.max_lora_rank
148
148
  )
@@ -157,13 +157,13 @@ class LoRAMemoryPool:
157
157
 
158
158
  init_buffer(
159
159
  self.A_buffer,
160
- self.lora_weight_names,
160
+ self.target_modules,
161
161
  self.get_lora_A_shape,
162
162
  )
163
163
 
164
164
  init_buffer(
165
165
  self.B_buffer,
166
- self.lora_weight_names,
166
+ self.target_modules,
167
167
  self.get_lora_B_shape,
168
168
  )
169
169
 
@@ -242,32 +242,34 @@ class LoRAMemoryPool:
242
242
  for layer_id in range(self.num_layer):
243
243
  layer_weights = lora_adapter.layers[layer_id].weights
244
244
  temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
245
- weight_name: None for weight_name in self.A_buffer
245
+ target_module: None for target_module in self.A_buffer
246
246
  }
247
247
  temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
248
- weight_name: None for weight_name in self.B_buffer
248
+ target_module: None for target_module in self.B_buffer
249
249
  }
250
250
  for name, weights in layer_weights.items():
251
- lora_weight_name = get_weight_name(name, self.lora_weight_names)
251
+ target_module = get_target_module_name(name, self.target_modules)
252
252
  if "lora_A" in name:
253
- temp_A_buffer[lora_weight_name] = weights
253
+ temp_A_buffer[target_module] = weights
254
254
  else:
255
- temp_B_buffer[lora_weight_name] = weights
255
+ temp_B_buffer[target_module] = weights
256
256
 
257
257
  if self.tp_size > 1:
258
258
  cur_layer_modules = lora_modules[layer_id]
259
259
  for module_name, module in cur_layer_modules.items():
260
- weight_name = get_weight_name(module_name, self.lora_weight_names)
260
+ target_module = get_target_module_name(
261
+ module_name, self.target_modules
262
+ )
261
263
 
262
- if temp_A_buffer[weight_name] is None:
264
+ if temp_A_buffer[target_module] is None:
263
265
  # Skip weight slicing if the weight is not present in the adapter
264
266
  continue
265
267
 
266
- temp_A_buffer[weight_name] = module.slice_lora_a_weights(
267
- temp_A_buffer[weight_name], self.tp_rank
268
+ temp_A_buffer[target_module] = module.slice_lora_a_weights(
269
+ temp_A_buffer[target_module], self.tp_rank
268
270
  )
269
- temp_B_buffer[weight_name] = module.slice_lora_b_weights(
270
- temp_B_buffer[weight_name], self.tp_rank
271
+ temp_B_buffer[target_module] = module.slice_lora_b_weights(
272
+ temp_B_buffer[target_module], self.tp_rank
271
273
  )
272
274
 
273
275
  for name, weights in temp_A_buffer.items():
@@ -282,12 +284,12 @@ class LoRAMemoryPool:
282
284
  load_lora_weight_tensor(buffer_view, weights)
283
285
 
284
286
  def get_tensor(
285
- self, weight_name: str, layer_id: int, lora_type: LoRAType
287
+ self, target_module: str, layer_id: int, lora_type: LoRAType
286
288
  ) -> torch.Tensor:
287
289
  if lora_type == LoRAType.LORA_A:
288
- return self.A_buffer[weight_name][layer_id]
290
+ return self.A_buffer[target_module][layer_id]
289
291
 
290
- return self.B_buffer[weight_name][layer_id]
292
+ return self.B_buffer[target_module][layer_id]
291
293
 
292
294
  def get_buffer_id(self, lora_uid: str):
293
295
  return self.uid_to_buffer_id[lora_uid]
sglang/srt/lora/utils.py CHANGED
@@ -84,7 +84,7 @@ def get_hidden_dim(
84
84
  raise NotImplementedError()
85
85
 
86
86
 
87
- def get_normalized_lora_weight_names(
87
+ def get_normalized_target_modules(
88
88
  target_modules: Iterable[str],
89
89
  ) -> set[str]:
90
90
  """
@@ -100,8 +100,8 @@ def get_normalized_lora_weight_names(
100
100
 
101
101
  result = set()
102
102
  for name in target_modules:
103
- weight_name = params_mapping.get(name, name)
104
- result.add(weight_name)
103
+ normalized_name = params_mapping.get(name, name)
104
+ result.add(normalized_name)
105
105
  return result
106
106
 
107
107
 
@@ -116,20 +116,18 @@ def get_stacked_multiply(module_name: str) -> int:
116
116
  return stacked_rank[module_name] if module_name in stacked_rank else 1
117
117
 
118
118
 
119
- def get_weight_name(
120
- target_name: str, lora_weight_names: Tuple[Set[str]]
121
- ) -> Optional[str]:
119
+ def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
122
120
  """
123
- Get the weight name in lora_weight_names that can match target_name.
121
+ Get the target module name in target_modules that can match full_module_name.
124
122
 
125
- If there is a weight name in lora_weight_names that can match target_name, return this name
123
+ If there is a target module name in target_modules that can match full_module_name, return this name
126
124
  Else raise ValueError.
127
125
  """
128
- for weight_name in lora_weight_names:
129
- if weight_name in target_name:
130
- return weight_name
126
+ for target_module in target_modules:
127
+ if target_module in full_module_name:
128
+ return target_module
131
129
  raise ValueError(
132
- f"Cannot find weight name for {target_name} in {lora_weight_names}"
130
+ f"Cannot find target module name for {full_module_name} in {target_modules}"
133
131
  )
134
132
 
135
133
 
@@ -26,6 +26,8 @@ if TYPE_CHECKING:
26
26
  from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
27
27
  from sglang.srt.mem_cache.memory_pool_host import HostKVCache
28
28
 
29
+ from sglang.srt.distributed import get_tensor_model_parallel_rank
30
+ from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost
29
31
 
30
32
  logger = logging.getLogger(__name__)
31
33
 
@@ -238,13 +240,14 @@ class HiCacheController:
238
240
  self.io_backend = io_backend
239
241
 
240
242
  self.enable_storage = False
243
+ self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
241
244
  # todo: move backend initialization to storage backend module
242
245
  if storage_backend is not None:
243
246
  self.storage_backend_type = storage_backend
244
247
  from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
245
248
 
246
249
  if storage_backend == "file":
247
- self.storage_backend = HiCacheFile()
250
+ self.storage_backend = HiCacheFile(is_mla=self.is_mla)
248
251
  self.get_hash_str = get_hash_str
249
252
  elif storage_backend == "nixl":
250
253
  from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
@@ -257,23 +260,26 @@ class HiCacheController:
257
260
  get_hash_str_mooncake,
258
261
  )
259
262
 
260
- self.storage_backend = MooncakeStore()
263
+ self.storage_backend = MooncakeStore(is_mla=self.is_mla)
261
264
  self.get_hash_str = get_hash_str_mooncake
262
265
  self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
263
266
  assert self.mem_pool_host.layout == "page_first"
264
267
  elif storage_backend == "hf3fs":
265
- from sglang.srt.distributed import get_tensor_model_parallel_rank
266
268
  from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
267
269
  HiCacheHF3FS,
268
270
  )
269
271
 
270
- rank = get_tensor_model_parallel_rank()
271
- bytes_per_page = (
272
- mem_pool_host.get_size_per_token() * mem_pool_host.page_size
273
- )
272
+ if self.mem_pool_host.layout == "page_first":
273
+ bytes_per_page = (
274
+ mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
275
+ )
276
+ elif self.mem_pool_host.layout == "layer_first":
277
+ bytes_per_page = (
278
+ mem_pool_host.get_size_per_token() * mem_pool_host.page_size
279
+ )
274
280
  dtype = mem_pool_host.dtype
275
281
  self.storage_backend = HiCacheHF3FS.from_env_config(
276
- rank, bytes_per_page, dtype
282
+ bytes_per_page, dtype
277
283
  )
278
284
  self.get_hash_str = get_hash_str
279
285
  else:
@@ -296,6 +302,9 @@ class HiCacheController:
296
302
  self.prefetch_tp_group = torch.distributed.new_group(
297
303
  group_ranks, backend="gloo"
298
304
  )
305
+ self.prefetch_io_tp_group = torch.distributed.new_group(
306
+ group_ranks, backend="gloo"
307
+ )
299
308
  self.backup_tp_group = torch.distributed.new_group(
300
309
  group_ranks, backend="gloo"
301
310
  )
@@ -391,6 +400,15 @@ class HiCacheController:
391
400
  self.prefetch_thread.start()
392
401
  self.backup_thread.start()
393
402
 
403
+ @property
404
+ def backup_skip(self):
405
+ return (
406
+ self.is_mla
407
+ and get_tensor_model_parallel_rank() != 0
408
+ # todo: only support file and mooncake
409
+ and self.storage_backend_type in ["file", "mooncake"]
410
+ )
411
+
394
412
  def write(
395
413
  self,
396
414
  device_indices: torch.Tensor,
@@ -552,13 +570,34 @@ class HiCacheController:
552
570
  operation.mark_done()
553
571
  return operation.completed_tokens, operation.hash_value
554
572
 
573
+ def zerocopy_page_transfer(self, operation, batch_size=8):
574
+ hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
575
+ operation.hash_value, operation.host_indices
576
+ )
577
+ for i in range(0, len(hashes), batch_size):
578
+ page_hashes = hashes[i : i + batch_size]
579
+ page_dsts = dsts[i : i + batch_size]
580
+ page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
581
+ if page_data is None:
582
+ logger.warning(
583
+ f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
584
+ )
585
+ break
586
+ completed_tokens = operation.completed_tokens
587
+ if operation.increment(self.page_size * len(page_hashes)):
588
+ for i in range(len(page_hashes)):
589
+ completed_tokens += self.page_size
590
+ else:
591
+ break
592
+
555
593
  def generic_page_transfer(self, operation, batch_size=8):
556
594
  for i in range(0, len(operation.hash_value), batch_size):
557
595
  page_hashes = operation.hash_value[i : i + batch_size]
558
596
  # todo: zero copy
559
- dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
560
- page_hashes
561
- )
597
+ dummy_page_dst = [
598
+ self.mem_pool_host.get_dummy_flat_data_page()
599
+ for _ in range(len(page_hashes))
600
+ ]
562
601
  page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
563
602
  if page_data is None:
564
603
  logger.warning(
@@ -596,13 +635,16 @@ class HiCacheController:
596
635
  if self.is_mooncake_backend():
597
636
  self.mooncake_page_transfer(operation)
598
637
  elif self.storage_backend_type == "hf3fs":
599
- self.generic_page_transfer(operation, batch_size=128)
638
+ if self.mem_pool_host.layout == "page_first":
639
+ self.zerocopy_page_transfer(operation, batch_size=128)
640
+ elif self.mem_pool_host.layout == "layer_first":
641
+ self.generic_page_transfer(operation, batch_size=128)
600
642
  else:
601
643
  self.generic_page_transfer(operation)
602
644
 
603
645
  if self.tp_world_size > 1:
604
646
  # to ensure all TP workers release the host memory at the same time
605
- torch.distributed.barrier(group=self.prefetch_tp_group)
647
+ torch.distributed.barrier(group=self.prefetch_io_tp_group)
606
648
  # operation terminated by controller, release pre-allocated memory
607
649
  self.mem_pool_host.free(
608
650
  operation.host_indices[operation.completed_tokens :]
@@ -713,6 +755,19 @@ class HiCacheController:
713
755
  self.backup_queue.put(operation)
714
756
  return operation.id
715
757
 
758
+ def zerocopy_page_backup(self, operation, batch_size=8):
759
+ hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
760
+ operation.hash_value, operation.host_indices
761
+ )
762
+ for i in range(0, len(hashes), batch_size):
763
+ page_hashes = hashes[i : i + batch_size]
764
+ page_data = dsts[i : i + batch_size]
765
+ success = self.storage_backend.batch_set(page_hashes, page_data)
766
+ if not success:
767
+ logger.warning(f"Failed to write page {page_hashes} to storage.")
768
+ break
769
+ operation.completed_tokens += self.page_size * len(page_hashes)
770
+
716
771
  def generic_page_backup(self, operation, batch_size=8):
717
772
  for i in range(0, len(operation.hash_value), batch_size):
718
773
  page_hashes = operation.hash_value[i : i + batch_size]
@@ -764,14 +819,20 @@ class HiCacheController:
764
819
  if operation is None:
765
820
  continue
766
821
 
767
- if self.is_mooncake_backend():
768
- self.mooncake_page_backup(operation)
769
- elif self.storage_backend_type == "hf3fs":
770
- self.generic_page_backup(operation, batch_size=128)
822
+ if not self.backup_skip:
823
+ if self.is_mooncake_backend():
824
+ self.mooncake_page_backup(operation)
825
+ elif self.storage_backend_type == "hf3fs":
826
+ if self.mem_pool_host.layout == "page_first":
827
+ self.zerocopy_page_backup(operation, batch_size=128)
828
+ elif self.mem_pool_host.layout == "layer_first":
829
+ self.generic_page_backup(operation, batch_size=128)
830
+ else:
831
+ self.generic_page_backup(operation)
832
+ min_completed_tokens = operation.completed_tokens
771
833
  else:
772
- self.generic_page_backup(operation)
834
+ min_completed_tokens = len(operation.token_ids)
773
835
 
774
- min_completed_tokens = operation.completed_tokens
775
836
  if self.tp_world_size > 1:
776
837
  completed_tokens_tensor = torch.tensor(
777
838
  min_completed_tokens, dtype=torch.int
@@ -31,10 +31,12 @@ from sglang.srt.managers.io_struct import (
31
31
  BatchMultimodalOut,
32
32
  BatchStrOut,
33
33
  BatchTokenIDOut,
34
+ FreezeGCReq,
34
35
  )
35
36
  from sglang.srt.server_args import PortArgs, ServerArgs
36
37
  from sglang.srt.utils import (
37
38
  configure_logger,
39
+ freeze_gc,
38
40
  get_zmq_socket,
39
41
  kill_itself_when_parent_died,
40
42
  )
@@ -100,6 +102,7 @@ class DetokenizerManager:
100
102
  (BatchEmbeddingOut, self.handle_batch_embedding_out),
101
103
  (BatchTokenIDOut, self.handle_batch_token_id_out),
102
104
  (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
105
+ (FreezeGCReq, self.handle_freeze_gc_req),
103
106
  ]
104
107
  )
105
108
 
@@ -108,7 +111,8 @@ class DetokenizerManager:
108
111
  while True:
109
112
  recv_obj = self.recv_from_scheduler.recv_pyobj()
110
113
  output = self._request_dispatcher(recv_obj)
111
- self.send_to_tokenizer.send_pyobj(output)
114
+ if output is not None:
115
+ self.send_to_tokenizer.send_pyobj(output)
112
116
 
113
117
  def trim_matched_stop(
114
118
  self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
@@ -216,7 +220,7 @@ class DetokenizerManager:
216
220
  rids=recv_obj.rids,
217
221
  finished_reasons=recv_obj.finished_reasons,
218
222
  output_strs=output_strs,
219
- output_ids=recv_obj.output_ids,
223
+ output_ids=recv_obj.decode_ids,
220
224
  prompt_tokens=recv_obj.prompt_tokens,
221
225
  completion_tokens=recv_obj.completion_tokens,
222
226
  cached_tokens=recv_obj.cached_tokens,
@@ -247,6 +251,10 @@ class DetokenizerManager:
247
251
  cached_tokens=recv_obj.cached_tokens,
248
252
  )
249
253
 
254
+ def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
255
+ freeze_gc("Detokenizer Manager")
256
+ return None
257
+
250
258
 
251
259
  class LimitedCapacityDict(OrderedDict):
252
260
  def __init__(self, capacity: int, *args, **kwargs):
@@ -612,6 +612,8 @@ class EmbeddingReqInput:
612
612
 
613
613
  if self.sampling_params is None:
614
614
  self.sampling_params = [{}] * self.batch_size
615
+ elif isinstance(self.sampling_params, dict):
616
+ self.sampling_params = [self.sampling_params] * self.batch_size
615
617
  for i in range(self.batch_size):
616
618
  self.sampling_params[i]["max_new_tokens"] = 0
617
619
 
@@ -660,6 +662,8 @@ class TokenizedEmbeddingReqInput:
660
662
  token_type_ids: List[int]
661
663
  # Dummy sampling params for compatibility
662
664
  sampling_params: SamplingParams
665
+ # For data parallel rank routing
666
+ data_parallel_rank: Optional[int] = None
663
667
  # For dp balance
664
668
  dp_balance_id: int = -1
665
669
 
@@ -798,6 +802,8 @@ class UpdateWeightFromDiskReqInput:
798
802
  load_format: Optional[str] = None
799
803
  # Whether to abort all requests before updating weights
800
804
  abort_all_requests: bool = False
805
+ # Optional: Update weight version along with weights
806
+ weight_version: Optional[str] = None
801
807
 
802
808
 
803
809
  @dataclass
@@ -819,6 +825,8 @@ class UpdateWeightsFromDistributedReqInput:
819
825
  flush_cache: bool = True
820
826
  # Whether to abort all requests before updating weights
821
827
  abort_all_requests: bool = False
828
+ # Optional: Update weight version along with weights
829
+ weight_version: Optional[str] = None
822
830
 
823
831
 
824
832
  @dataclass
@@ -842,6 +850,8 @@ class UpdateWeightsFromTensorReqInput:
842
850
  flush_cache: bool = True
843
851
  # Whether to abort all requests before updating weights
844
852
  abort_all_requests: bool = False
853
+ # Optional: Update weight version along with weights
854
+ weight_version: Optional[str] = None
845
855
 
846
856
 
847
857
  @dataclass
@@ -872,6 +882,14 @@ class InitWeightsUpdateGroupReqOutput:
872
882
  message: str
873
883
 
874
884
 
885
+ @dataclass
886
+ class UpdateWeightVersionReqInput:
887
+ # The new weight version
888
+ new_version: str
889
+ # Whether to abort all running requests before updating
890
+ abort_all_requests: bool = True
891
+
892
+
875
893
  @dataclass
876
894
  class GetWeightsByNameReqInput:
877
895
  name: str
@@ -987,6 +1005,11 @@ class ProfileReqOutput:
987
1005
  message: str
988
1006
 
989
1007
 
1008
+ @dataclass
1009
+ class FreezeGCReq:
1010
+ pass
1011
+
1012
+
990
1013
  @dataclass
991
1014
  class ConfigureLoggingReq:
992
1015
  log_requests: Optional[bool] = None
@@ -560,7 +560,7 @@ def embed_mm_inputs(
560
560
  ]
561
561
  items_size[i + 1] = len(mm_items)
562
562
  items_offsets.append(
563
- flatten_nested_list([item.offsets for item in mm_inputs.mm_items])
563
+ flatten_nested_list([item.offsets for item in mm_items])
564
564
  )
565
565
  items_size = torch.cumsum(items_size, dim=0).tolist()
566
566
 
@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
52
52
  ScheduleBatchDisaggregationDecodeMixin,
53
53
  )
54
54
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
55
+ from sglang.srt.layers.moe import is_tbo_enabled
55
56
  from sglang.srt.mem_cache.allocator import (
56
57
  BaseTokenToKVPoolAllocator,
57
58
  SWATokenToKVPoolAllocator,
@@ -83,19 +84,13 @@ GLOBAL_SERVER_ARGS_KEYS = [
83
84
  "chunked_prefill_size",
84
85
  "device",
85
86
  "disable_chunked_prefix_cache",
87
+ "disable_flashinfer_cutlass_moe_fp4_allgather",
86
88
  "disable_radix_cache",
87
- "enable_dp_attention",
88
- "enable_two_batch_overlap",
89
- "tbo_token_distribution_threshold",
90
89
  "enable_dp_lm_head",
91
- "moe_a2a_backend",
92
- "deepep_mode",
93
- "enable_flashinfer_cutlass_moe",
94
- "enable_flashinfer_trtllm_moe",
90
+ "flashinfer_mxfp4_moe_precision",
95
91
  "enable_flashinfer_allreduce_fusion",
96
92
  "moe_dense_tp_size",
97
93
  "ep_dispatch_algorithm",
98
- "deepep_config",
99
94
  "ep_num_redundant_experts",
100
95
  "enable_nan_detection",
101
96
  "flashinfer_mla_disable_ragged",
@@ -108,11 +103,11 @@ GLOBAL_SERVER_ARGS_KEYS = [
108
103
  "triton_attention_reduce_in_fp32",
109
104
  "num_reserved_decode_tokens",
110
105
  "weight_loader_disable_mmap",
111
- "enable_triton_kernel_moe",
112
- "enable_flashinfer_mxfp4_moe",
113
106
  "enable_multimodal",
114
107
  "enable_symm_mem",
115
108
  "quantization",
109
+ "enable_custom_logit_processor",
110
+ "disaggregation_mode",
116
111
  ]
117
112
 
118
113
  # Put some global args for easy access
@@ -909,12 +904,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
909
904
  spec_algorithm: SpeculativeAlgorithm = None
910
905
  spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
911
906
 
912
- # Enable custom logit processor
913
- enable_custom_logit_processor: bool = False
914
-
915
907
  # Whether to return hidden states
916
908
  return_hidden_states: bool = False
917
909
 
910
+ # Whether this batch is prefill-only (no token generation needed)
911
+ is_prefill_only: bool = False
912
+
918
913
  # hicache pointer for synchronizing data loading from CPU to GPU
919
914
  hicache_consumer_index: int = 0
920
915
 
@@ -928,7 +923,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
928
923
  model_config: ModelConfig,
929
924
  enable_overlap: bool,
930
925
  spec_algorithm: SpeculativeAlgorithm,
931
- enable_custom_logit_processor: bool,
932
926
  chunked_req: Optional[Req] = None,
933
927
  ):
934
928
  return_logprob = any(req.return_logprob for req in reqs)
@@ -955,8 +949,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
955
949
  has_grammar=any(req.grammar for req in reqs),
956
950
  device=req_to_token_pool.device,
957
951
  spec_algorithm=spec_algorithm,
958
- enable_custom_logit_processor=enable_custom_logit_processor,
959
952
  return_hidden_states=any(req.return_hidden_states for req in reqs),
953
+ is_prefill_only=all(
954
+ req.sampling_params.max_new_tokens == 0 for req in reqs
955
+ ),
960
956
  chunked_req=chunked_req,
961
957
  )
962
958
 
@@ -1009,6 +1005,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1009
1005
  extend_num_tokens: int,
1010
1006
  backup_state: bool = False,
1011
1007
  ):
1008
+ # Over estimate the number of tokens: assume each request needs a new page.
1012
1009
  num_tokens = (
1013
1010
  extend_num_tokens
1014
1011
  + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
@@ -1041,8 +1038,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1041
1038
  last_loc: torch.Tensor,
1042
1039
  backup_state: bool = False,
1043
1040
  ):
1041
+ # Over estimate the number of tokens: assume each request needs a new page.
1044
1042
  num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
1045
-
1046
1043
  self._evict_tree_cache_if_needed(num_tokens)
1047
1044
 
1048
1045
  if backup_state:
@@ -1721,38 +1718,18 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1721
1718
  extend_prefix_lens = self.prefix_lens
1722
1719
  extend_logprob_start_lens = self.extend_logprob_start_lens
1723
1720
 
1724
- if self.forward_mode.is_decode_or_idle():
1725
- attention_backend_str = global_server_args_dict["decode_attention_backend"]
1726
- else:
1727
- attention_backend_str = global_server_args_dict["prefill_attention_backend"]
1728
- # Create seq_lens_cpu when needed
1729
- if (
1730
- attention_backend_str
1731
- in [
1732
- "fa3",
1733
- "flashinfer",
1734
- "flashmla",
1735
- "cutlass_mla",
1736
- "ascend",
1737
- "trtllm_mha",
1738
- "aiter",
1739
- ]
1740
- or global_server_args_dict["enable_two_batch_overlap"]
1741
- ):
1742
- seq_lens_cpu = (
1743
- seq_lens_cpu_cache
1744
- if seq_lens_cpu_cache is not None
1745
- else self.seq_lens.cpu()
1746
- )
1747
- else:
1748
- seq_lens_cpu = None
1749
-
1750
1721
  if self.sampling_info:
1751
1722
  if self.has_grammar:
1752
1723
  self.sampling_info.grammars = [req.grammar for req in self.reqs]
1753
1724
  else:
1754
1725
  self.sampling_info.grammars = None
1755
1726
 
1727
+ seq_lens_cpu = (
1728
+ seq_lens_cpu_cache
1729
+ if seq_lens_cpu_cache is not None
1730
+ else self.seq_lens.cpu()
1731
+ )
1732
+
1756
1733
  global bid
1757
1734
  bid += 1
1758
1735
  return ModelWorkerBatch(
@@ -1815,18 +1792,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1815
1792
  return_logprob=self.return_logprob,
1816
1793
  decoding_reqs=self.decoding_reqs,
1817
1794
  spec_algorithm=self.spec_algorithm,
1818
- enable_custom_logit_processor=self.enable_custom_logit_processor,
1819
1795
  global_num_tokens=self.global_num_tokens,
1820
1796
  global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
1821
1797
  can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
1822
1798
  is_extend_in_batch=self.is_extend_in_batch,
1799
+ is_prefill_only=self.is_prefill_only,
1823
1800
  )
1824
1801
 
1825
- def _evict_tree_cache_if_needed(
1826
- self,
1827
- num_tokens: int,
1828
- ) -> None:
1829
- if isinstance(self.tree_cache, SWAChunkCache):
1802
+ def _evict_tree_cache_if_needed(self, num_tokens: int):
1803
+ if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
1830
1804
  return
1831
1805
 
1832
1806
  if self.is_hybrid: