sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/lora/mem_pool.py CHANGED
@@ -13,9 +13,9 @@ from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
-    get_normalized_lora_weight_names,
+    get_normalized_target_modules,
     get_stacked_multiply,
-    get_weight_name,
+    get_target_module_name,
 )
 
 logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class LoRAMemoryPool:
         tp_size: int,
         tp_rank: int,
         max_lora_rank: int,
-        lora_weight_names: Set[str],
+        target_modules: Set[str],
         base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
@@ -62,7 +62,7 @@ class LoRAMemoryPool:
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
         self.max_lora_rank: int = max_lora_rank
-        self.lora_weight_names: Set[str] = lora_weight_names
+        self.target_modules: Set[str] = target_modules
 
         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -95,8 +95,8 @@ class LoRAMemoryPool:
             """
             if config.r > self.max_lora_rank:
                 return False
-            weights = get_normalized_lora_weight_names(config.target_modules)
-            return weights.issubset(self.lora_weight_names)
+            target_module_names = get_normalized_target_modules(config.target_modules)
+            return target_module_names.issubset(self.target_modules)
 
         if isinstance(config, LoRAConfig):
             return _can_support(config)
@@ -139,10 +139,10 @@ class LoRAMemoryPool:
 
         def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
-            lora_weight_names: Set[str],
+            target_modules: Set[str],
             get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
         ):
-            for module_name in lora_weight_names:
+            for module_name in target_modules:
                 lora_shape = get_lora_shape_fn(
                     module_name, base_model, self.max_lora_rank
                 )
@@ -157,13 +157,13 @@ class LoRAMemoryPool:
 
         init_buffer(
             self.A_buffer,
-            self.lora_weight_names,
+            self.target_modules,
             self.get_lora_A_shape,
         )
 
         init_buffer(
             self.B_buffer,
-            self.lora_weight_names,
+            self.target_modules,
             self.get_lora_B_shape,
         )
 
@@ -242,32 +242,34 @@ class LoRAMemoryPool:
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
             temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
-                weight_name: None for weight_name in self.A_buffer
+                target_module: None for target_module in self.A_buffer
             }
             temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
-                weight_name: None for weight_name in self.B_buffer
+                target_module: None for target_module in self.B_buffer
             }
             for name, weights in layer_weights.items():
-                lora_weight_name = get_weight_name(name, self.lora_weight_names)
+                target_module = get_target_module_name(name, self.target_modules)
                 if "lora_A" in name:
-                    temp_A_buffer[lora_weight_name] = weights
+                    temp_A_buffer[target_module] = weights
                 else:
-                    temp_B_buffer[lora_weight_name] = weights
+                    temp_B_buffer[target_module] = weights
 
             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
-                    weight_name = get_weight_name(module_name, self.lora_weight_names)
+                    target_module = get_target_module_name(
+                        module_name, self.target_modules
+                    )
 
-                    if temp_A_buffer[weight_name] is None:
+                    if temp_A_buffer[target_module] is None:
                         # Skip weight slicing if the weight is not present in the adapter
                         continue
 
-                    temp_A_buffer[weight_name] = module.slice_lora_a_weights(
-                        temp_A_buffer[weight_name], self.tp_rank
+                    temp_A_buffer[target_module] = module.slice_lora_a_weights(
+                        temp_A_buffer[target_module], self.tp_rank
                     )
-                    temp_B_buffer[weight_name] = module.slice_lora_b_weights(
-                        temp_B_buffer[weight_name], self.tp_rank
+                    temp_B_buffer[target_module] = module.slice_lora_b_weights(
+                        temp_B_buffer[target_module], self.tp_rank
                     )
 
             for name, weights in temp_A_buffer.items():
@@ -282,12 +284,12 @@ class LoRAMemoryPool:
                 load_lora_weight_tensor(buffer_view, weights)
 
     def get_tensor(
-        self, weight_name: str, layer_id: int, lora_type: LoRAType
+        self, target_module: str, layer_id: int, lora_type: LoRAType
    ) -> torch.Tensor:
         if lora_type == LoRAType.LORA_A:
-            return self.A_buffer[weight_name][layer_id]
+            return self.A_buffer[target_module][layer_id]
 
-        return self.B_buffer[weight_name][layer_id]
+        return self.B_buffer[target_module][layer_id]
 
     def get_buffer_id(self, lora_uid: str):
         return self.uid_to_buffer_id[lora_uid]
sglang/srt/lora/utils.py CHANGED
@@ -84,7 +84,7 @@ def get_hidden_dim(
     raise NotImplementedError()
 
 
-def get_normalized_lora_weight_names(
+def get_normalized_target_modules(
     target_modules: Iterable[str],
 ) -> set[str]:
     """
@@ -100,8 +100,8 @@ def get_normalized_lora_weight_names(
 
     result = set()
     for name in target_modules:
-        weight_name = params_mapping.get(name, name)
-        result.add(weight_name)
+        normalized_name = params_mapping.get(name, name)
+        result.add(normalized_name)
     return result
 
 
@@ -116,20 +116,18 @@ def get_stacked_multiply(module_name: str) -> int:
     return stacked_rank[module_name] if module_name in stacked_rank else 1
 
 
-def get_weight_name(
-    target_name: str, lora_weight_names: Tuple[Set[str]]
-) -> Optional[str]:
+def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
     """
-    Get the weight name in lora_weight_names that can match target_name.
+    Get the target module name in target_modules that can match full_module_name.
 
-    If there is a weight name in lora_weight_names that can match target_name, return this name
+    If there is a target module name in target_modules that can match full_module_name, return this name
     Else raise ValueError.
     """
-    for weight_name in lora_weight_names:
-        if weight_name in target_name:
-            return weight_name
+    for target_module in target_modules:
+        if target_module in full_module_name:
+            return target_module
     raise ValueError(
-        f"Cannot find weight name for {target_name} in {lora_weight_names}"
+        f"Cannot find target module name for {full_module_name} in {target_modules}"
     )
 
 
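To make the rename concrete, the following is a small self-contained sketch of how the two renamed helpers behave. The `PARAMS_MAPPING` table below is a made-up stand-in for the normalization mapping kept in `sglang.srt.lora.utils`; only the two function bodies mirror the diff above.

```python
from typing import Dict, Iterable, Set

# Hypothetical normalization table; the real mapping lives in sglang.srt.lora.utils.
PARAMS_MAPPING: Dict[str, str] = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}


def get_normalized_target_modules(target_modules: Iterable[str]) -> Set[str]:
    # Map user-facing module names (q_proj, ...) onto the stacked modules
    # the memory pool actually allocates buffers for (qkv_proj, ...).
    return {PARAMS_MAPPING.get(name, name) for name in target_modules}


def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
    # Match a fully qualified module path against the normalized target set.
    for target_module in target_modules:
        if target_module in full_module_name:
            return target_module
    raise ValueError(
        f"Cannot find target module name for {full_module_name} in {target_modules}"
    )


targets = get_normalized_target_modules(["q_proj", "up_proj", "o_proj"])
print(targets)  # e.g. {'qkv_proj', 'gate_up_proj', 'o_proj'}
print(get_target_module_name("model.layers.0.self_attn.qkv_proj", targets))  # qkv_proj
```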
sglang/srt/managers/cache_controller.py CHANGED
@@ -26,6 +26,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost
 
 logger = logging.getLogger(__name__)
 
@@ -238,13 +240,14 @@ class HiCacheController:
         self.io_backend = io_backend
 
         self.enable_storage = False
+        self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
             from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
 
             if storage_backend == "file":
-                self.storage_backend = HiCacheFile()
+                self.storage_backend = HiCacheFile(is_mla=self.is_mla)
                 self.get_hash_str = get_hash_str
             elif storage_backend == "nixl":
                 from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
@@ -257,23 +260,26 @@ class HiCacheController:
                     get_hash_str_mooncake,
                 )
 
-                self.storage_backend = MooncakeStore()
+                self.storage_backend = MooncakeStore(is_mla=self.is_mla)
                 self.get_hash_str = get_hash_str_mooncake
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
                 assert self.mem_pool_host.layout == "page_first"
             elif storage_backend == "hf3fs":
-                from sglang.srt.distributed import get_tensor_model_parallel_rank
                 from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
                     HiCacheHF3FS,
                 )
 
-                rank = get_tensor_model_parallel_rank()
-                bytes_per_page = (
-                    mem_pool_host.get_size_per_token() * mem_pool_host.page_size
-                )
+                if self.mem_pool_host.layout == "page_first":
+                    bytes_per_page = (
+                        mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
+                    )
+                elif self.mem_pool_host.layout == "layer_first":
+                    bytes_per_page = (
+                        mem_pool_host.get_size_per_token() * mem_pool_host.page_size
+                    )
                 dtype = mem_pool_host.dtype
                 self.storage_backend = HiCacheHF3FS.from_env_config(
-                    rank, bytes_per_page, dtype
+                    bytes_per_page, dtype
                 )
                 self.get_hash_str = get_hash_str
             else:
@@ -394,6 +400,15 @@ class HiCacheController:
         self.prefetch_thread.start()
         self.backup_thread.start()
 
+    @property
+    def backup_skip(self):
+        return (
+            self.is_mla
+            and get_tensor_model_parallel_rank() != 0
+            # todo: only support file and mooncake
+            and self.storage_backend_type in ["file", "mooncake"]
+        )
+
     def write(
         self,
         device_indices: torch.Tensor,
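For reference, the new `backup_skip` predicate can be read as a standalone rule; this is a minimal sketch, not sglang's class, and the rationale is inferred (presumably MLA KV pages are replicated across tensor-parallel ranks, so writing them once from rank 0 is enough, and for now only the `file` and `mooncake` backends opt in).

```python
# Standalone restatement of the backup_skip property added above.
def should_skip_backup(is_mla: bool, tp_rank: int, storage_backend_type: str) -> bool:
    return is_mla and tp_rank != 0 and storage_backend_type in ("file", "mooncake")


assert should_skip_backup(True, 1, "mooncake")        # non-zero MLA ranks skip
assert not should_skip_backup(True, 0, "mooncake")    # rank 0 still backs up
assert not should_skip_backup(False, 1, "file")       # non-MLA models: every rank writes
assert not should_skip_backup(True, 1, "hf3fs")       # only file/mooncake are covered so far
```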
@@ -555,13 +570,34 @@ class HiCacheController:
             operation.mark_done()
             return operation.completed_tokens, operation.hash_value
 
+    def zerocopy_page_transfer(self, operation, batch_size=8):
+        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+            operation.hash_value, operation.host_indices
+        )
+        for i in range(0, len(hashes), batch_size):
+            page_hashes = hashes[i : i + batch_size]
+            page_dsts = dsts[i : i + batch_size]
+            page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
+            if page_data is None:
+                logger.warning(
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
+                )
+                break
+            completed_tokens = operation.completed_tokens
+            if operation.increment(self.page_size * len(page_hashes)):
+                for i in range(len(page_hashes)):
+                    completed_tokens += self.page_size
+            else:
+                break
+
     def generic_page_transfer(self, operation, batch_size=8):
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
             # todo: zero copy
-            dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
-                page_hashes
-            )
+            dummy_page_dst = [
+                self.mem_pool_host.get_dummy_flat_data_page()
+                for _ in range(len(page_hashes))
+            ]
             page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
             if page_data is None:
                 logger.warning(
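The `dummy_page_dst` change above swaps list multiplication for a comprehension. The practical difference, illustrated below with plain tensors (and presumably the motivation for the change), is that `[x] * n` repeats one object n times, so every `batch_get` destination would alias the same buffer, whereas the comprehension allocates a distinct page per hash.

```python
import torch

n = 4
aliased = [torch.zeros(2)] * n                  # one tensor, n references to it
distinct = [torch.zeros(2) for _ in range(n)]   # n independent tensors

aliased[0][0] = 1.0
print([t[0].item() for t in aliased])    # [1.0, 1.0, 1.0, 1.0]: every slot changed
distinct[0][0] = 1.0
print([t[0].item() for t in distinct])   # [1.0, 0.0, 0.0, 0.0]
```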
@@ -599,7 +635,10 @@ class HiCacheController:
             if self.is_mooncake_backend():
                 self.mooncake_page_transfer(operation)
             elif self.storage_backend_type == "hf3fs":
-                self.generic_page_transfer(operation, batch_size=128)
+                if self.mem_pool_host.layout == "page_first":
+                    self.zerocopy_page_transfer(operation, batch_size=128)
+                elif self.mem_pool_host.layout == "layer_first":
+                    self.generic_page_transfer(operation, batch_size=128)
             else:
                 self.generic_page_transfer(operation)
 
@@ -716,6 +755,19 @@ class HiCacheController:
             self.backup_queue.put(operation)
             return operation.id
 
+    def zerocopy_page_backup(self, operation, batch_size=8):
+        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+            operation.hash_value, operation.host_indices
+        )
+        for i in range(0, len(hashes), batch_size):
+            page_hashes = hashes[i : i + batch_size]
+            page_data = dsts[i : i + batch_size]
+            success = self.storage_backend.batch_set(page_hashes, page_data)
+            if not success:
+                logger.warning(f"Failed to write page {page_hashes} to storage.")
+                break
+            operation.completed_tokens += self.page_size * len(page_hashes)
+
     def generic_page_backup(self, operation, batch_size=8):
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
@@ -767,14 +819,20 @@ class HiCacheController:
             if operation is None:
                 continue
 
-            if self.is_mooncake_backend():
-                self.mooncake_page_backup(operation)
-            elif self.storage_backend_type == "hf3fs":
-                self.generic_page_backup(operation, batch_size=128)
+            if not self.backup_skip:
+                if self.is_mooncake_backend():
+                    self.mooncake_page_backup(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    if self.mem_pool_host.layout == "page_first":
+                        self.zerocopy_page_backup(operation, batch_size=128)
+                    elif self.mem_pool_host.layout == "layer_first":
+                        self.generic_page_backup(operation, batch_size=128)
+                else:
+                    self.generic_page_backup(operation)
+                min_completed_tokens = operation.completed_tokens
             else:
-                self.generic_page_backup(operation)
+                min_completed_tokens = len(operation.token_ids)
 
-            min_completed_tokens = operation.completed_tokens
             if self.tp_world_size > 1:
                 completed_tokens_tensor = torch.tensor(
                     min_completed_tokens, dtype=torch.int
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -31,10 +31,12 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    FreezeGCReq,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
+    freeze_gc,
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
@@ -100,6 +102,7 @@ class DetokenizerManager:
                 (BatchEmbeddingOut, self.handle_batch_embedding_out),
                 (BatchTokenIDOut, self.handle_batch_token_id_out),
                 (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
+                (FreezeGCReq, self.handle_freeze_gc_req),
             ]
         )
 
@@ -108,7 +111,8 @@ class DetokenizerManager:
         while True:
             recv_obj = self.recv_from_scheduler.recv_pyobj()
             output = self._request_dispatcher(recv_obj)
-            self.send_to_tokenizer.send_pyobj(output)
+            if output is not None:
+                self.send_to_tokenizer.send_pyobj(output)
 
     def trim_matched_stop(
         self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
@@ -216,7 +220,7 @@ class DetokenizerManager:
             rids=recv_obj.rids,
             finished_reasons=recv_obj.finished_reasons,
             output_strs=output_strs,
-            output_ids=recv_obj.output_ids,
+            output_ids=recv_obj.decode_ids,
             prompt_tokens=recv_obj.prompt_tokens,
             completion_tokens=recv_obj.completion_tokens,
             cached_tokens=recv_obj.cached_tokens,
@@ -247,6 +251,10 @@ class DetokenizerManager:
             cached_tokens=recv_obj.cached_tokens,
         )
 
+    def handle_freeze_gc_req(self, recv_req: FreezeGCReq):
+        freeze_gc("Detokenizer Manager")
+        return None
+
 
 class LimitedCapacityDict(OrderedDict):
     def __init__(self, capacity: int, *args, **kwargs):
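Taken together with the scheduler changes further down, the new `FreezeGCReq` message lets each process freeze its garbage collector after startup. A hedged sketch of the idea, assuming `freeze_gc` in `sglang.srt.utils` is a thin wrapper around CPython's `gc.freeze()` (the logging detail here is illustrative, not the library's exact output):

```python
import gc
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class FreezeGCReq:
    """Control message forwarded from the scheduler to the detokenizer."""


def freeze_gc(component: str) -> None:
    # gc.freeze() moves all currently tracked objects into a permanent generation,
    # so later collections no longer scan long-lived startup allocations.
    before = gc.get_freeze_count()
    gc.freeze()
    logger.info("%s froze %d objects", component, gc.get_freeze_count() - before)


def handle_freeze_gc_req(recv_req: FreezeGCReq) -> None:
    # Mirrors DetokenizerManager.handle_freeze_gc_req above: freeze locally and
    # return None so nothing is forwarded on to the tokenizer.
    freeze_gc("Detokenizer Manager")
```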
sglang/srt/managers/io_struct.py CHANGED
@@ -612,6 +612,8 @@ class EmbeddingReqInput:
 
         if self.sampling_params is None:
             self.sampling_params = [{}] * self.batch_size
+        elif isinstance(self.sampling_params, dict):
+            self.sampling_params = [self.sampling_params] * self.batch_size
         for i in range(self.batch_size):
             self.sampling_params[i]["max_new_tokens"] = 0
 
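In other words, a single sampling-params dict is now accepted and broadcast across the whole batch before `max_new_tokens` is forced to 0 for every item. A small standalone illustration (the field names are assumed to mirror `EmbeddingReqInput`; this is not the class itself):

```python
from typing import Dict, List, Optional, Union


def normalize_sampling_params(
    sampling_params: Optional[Union[Dict, List[Dict]]], batch_size: int
) -> List[Dict]:
    # None -> one empty dict per item; a single dict -> repeated for every item.
    if sampling_params is None:
        sampling_params = [{} for _ in range(batch_size)]
    elif isinstance(sampling_params, dict):
        sampling_params = [sampling_params] * batch_size
    for i in range(batch_size):
        sampling_params[i]["max_new_tokens"] = 0  # embedding requests never generate
    return sampling_params


print(normalize_sampling_params({"temperature": 0.0}, 3))
# three entries, each {'temperature': 0.0, 'max_new_tokens': 0}
```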
@@ -660,6 +662,8 @@ class TokenizedEmbeddingReqInput:
     token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
+    # For data parallel rank routing
+    data_parallel_rank: Optional[int] = None
     # For dp balance
     dp_balance_id: int = -1
 
@@ -1001,6 +1005,11 @@ class ProfileReqOutput:
     message: str
 
 
+@dataclass
+class FreezeGCReq:
+    pass
+
+
 @dataclass
 class ConfigureLoggingReq:
     log_requests: Optional[bool] = None
sglang/srt/managers/mm_utils.py CHANGED
@@ -560,7 +560,7 @@ def embed_mm_inputs(
             ]
             items_size[i + 1] = len(mm_items)
             items_offsets.append(
-                flatten_nested_list([item.offsets for item in mm_inputs.mm_items])
+                flatten_nested_list([item.offsets for item in mm_items])
             )
         items_size = torch.cumsum(items_size, dim=0).tolist()
 
sglang/srt/managers/schedule_batch.py CHANGED
@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+from sglang.srt.layers.moe import is_tbo_enabled
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -83,18 +84,13 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "chunked_prefill_size",
     "device",
     "disable_chunked_prefix_cache",
+    "disable_flashinfer_cutlass_moe_fp4_allgather",
     "disable_radix_cache",
-    "enable_two_batch_overlap",
-    "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
-    "moe_a2a_backend",
-    "deepep_mode",
-    "enable_flashinfer_cutlass_moe",
-    "enable_flashinfer_trtllm_moe",
+    "flashinfer_mxfp4_moe_precision",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
-    "deepep_config",
     "ep_num_redundant_experts",
     "enable_nan_detection",
     "flashinfer_mla_disable_ragged",
@@ -107,12 +103,11 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
-    "enable_triton_kernel_moe",
-    "enable_flashinfer_mxfp4_moe",
     "enable_multimodal",
     "enable_symm_mem",
     "quantization",
     "enable_custom_logit_processor",
+    "disaggregation_mode",
 ]
 
 # Put some global args for easy access
sglang/srt/managers/scheduler.py CHANGED
@@ -64,7 +64,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
+from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
     ExpertDistributionReqOutput,
     FlushCacheReqInput,
     FlushCacheReqOutput,
+    FreezeGCReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
@@ -145,6 +146,7 @@ from sglang.srt.utils import (
     configure_gc_logger,
     configure_logger,
     disable_request_logging,
+    freeze_gc,
     get_available_gpu_memory,
     get_bool_env_var,
     get_zmq_socket,
@@ -245,6 +247,9 @@ class Scheduler(
             )
         )
 
+        # Init model config
+        self.model_config = ModelConfig.from_server_args(server_args)
+
         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None
@@ -292,6 +297,9 @@ class Scheduler(
         # Init tokenizer
         self.init_tokenizer()
 
+        # Init moe config
+        self.init_moe_config()
+
         # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
         if self.server_args.reasoning_parser and self.tokenizer:
             reasoning_parser = ReasoningParser(
@@ -518,6 +526,7 @@ class Scheduler(
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                 (SlowDownReqInput, self.slow_down),
                 (ProfileReq, self.profile),
+                (FreezeGCReq, self.handle_freeze_gc),
                 (GetInternalStateReq, self.get_internal_state),
                 (SetInternalStateReq, self.set_internal_state),
                 (RpcReqInput, self.handle_rpc_request),
@@ -538,8 +547,6 @@ class Scheduler(
 
     def init_tokenizer(self):
         server_args = self.server_args
-
-        self.model_config = ModelConfig.from_server_args(server_args)
         self.is_generation = self.model_config.is_generation
 
         if server_args.skip_tokenizer_init:
@@ -761,6 +768,10 @@ class Scheduler(
         # The prefill requests that are in the middle of kv sending
         self.disagg_prefill_inflight_queue: List[Req] = []
 
+    def init_moe_config(self):
+        if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
+            initialize_moe_config(self.server_args)
+
     @DynamicGradMode()
     def event_loop_normal(self):
         """A normal scheduler loop."""
@@ -1133,7 +1144,7 @@ class Scheduler(
                     f"boostrap room id. {req.rid=}"
                 )
                 logger.error(error_msg)
-                prepare_abort(req, error_msg)
+                prepare_abort(req, error_msg, status_code=HTTPStatus.BAD_REQUEST)
                 self.stream_output([req], req.return_logprob)
                 return
 
@@ -1823,11 +1834,6 @@ class Scheduler(
             disable_cuda_graph=self.server_args.disable_cuda_graph,
             spec_algorithm=self.spec_algorithm,
             speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
-            enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
-            enable_deepep_moe=MoeA2ABackend(
-                self.server_args.moe_a2a_backend
-            ).is_deepep(),
-            deepep_mode=DeepEPMode(self.server_args.deepep_mode),
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )
@@ -1922,9 +1928,6 @@ class Scheduler(
         disable_cuda_graph: bool,
         spec_algorithm,
         speculative_num_draft_tokens,
-        enable_two_batch_overlap: bool,
-        enable_deepep_moe: bool,
-        deepep_mode: DeepEPMode,
         require_mlp_tp_gather: bool,
         disable_overlap_schedule: bool,
     ):
@@ -1972,9 +1975,6 @@ class Scheduler(
                     is_extend_in_batch,
                     *tbo_preparer.prepare_all_gather(
                         local_batch,
-                        deepep_mode,
-                        enable_deepep_moe,
-                        enable_two_batch_overlap,
                     ),
                 ],
                 dtype=torch.int64,
@@ -2472,6 +2472,12 @@ class Scheduler(
         if self.idle_sleeper is not None:
            self.idle_sleeper.maybe_sleep()
 
+    def handle_freeze_gc(self, recv_req: FreezeGCReq):
+        """Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
+        freeze_gc("Scheduler")
+        self.send_to_detokenizer.send_pyobj(recv_req)
+        return None
+
 
 class IdleSleeper:
     """
@@ -2582,7 +2588,10 @@ def run_scheduler_process(
         if scheduler.enable_overlap:
             scheduler.event_loop_overlap_disagg_prefill()
         else:
-            scheduler.event_loop_normal_disagg_prefill()
+            if server_args.pp_size > 1:
+                scheduler.event_loop_pp_disagg_prefill()
+            else:
+                scheduler.event_loop_normal_disagg_prefill()
 
     elif disaggregation_mode == DisaggregationMode.DECODE:
         if scheduler.enable_overlap:
sglang/srt/managers/session_controller.py CHANGED
@@ -54,7 +54,7 @@ class SessionReqNode:
             prefix += " -- " + self.childs[0].req.rid
             ret = self.childs[0]._str_helper(prefix)
         for child in self.childs[1:]:
-            prefix = " " * len(origin_prefix) + " \- " + child.req.rid
+            prefix = " " * len(origin_prefix) + " \\- " + child.req.rid
             ret += child._str_helper(prefix)
         return ret
 
sglang/srt/managers/template_manager.py CHANGED
@@ -89,6 +89,7 @@ class TemplateManager:
         if template is None:
             return False
 
+        # TODO: remove this hard code the reasoning pattern
         force_reasoning_pattern = r"<\|im_start\|>assistant\\n<think>\\n"
         has_reasoning = re.search(force_reasoning_pattern, template) is not None
 
@@ -128,11 +129,12 @@ class TemplateManager:
             logger.info(
                 f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
             )
-            return
-
-        # Default to string content format if no template was found
-        self._jinja_template_content_format = "string"
-        logger.info("No chat template found, defaulting to 'string' content format")
+        else:
+            # Default to string content format if no template was found
+            self._jinja_template_content_format = "string"
+            logger.info(
+                "No chat template found, defaulting to 'string' content format"
+            )
 
         # Detect reasoning pattern from chat template
         if tokenizer_manager.tokenizer:
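The last hunk replaces an early `return` with an `else:` branch so the reasoning-pattern detection further down still runs when a default HF template is found. A compact restatement of the resulting control flow (a standalone sketch with assumed helper names, not the TemplateManager API):

```python
import logging
from typing import Callable, Optional

logger = logging.getLogger(__name__)


def resolve_content_format(
    default_template: Optional[str], detect_format: Callable[[str], str]
) -> str:
    # detect_format stands in for sglang's jinja content-format detection.
    if default_template is not None:
        fmt = detect_format(default_template)
        logger.info(
            "Using default HuggingFace chat template with detected content format: %s",
            fmt,
        )
    else:
        fmt = "string"
        logger.info("No chat template found, defaulting to 'string' content format")
    # Either way, execution continues to the reasoning-pattern detection afterwards.
    return fmt
```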