sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_utils.py CHANGED
@@ -49,6 +49,8 @@ SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
 
 TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
 
+TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
+
 
 @dataclass
 class EagleDraftInput:
@@ -177,11 +179,24 @@ class EagleDraftInput:
         )
         return kv_indices, cum_kv_seq_len, qo_indptr, None
 
-    def filter_batch(self, new_indices: torch.Tensor):
-        self.topk_p = self.topk_p[: len(new_indices)]
-        self.topk_index = self.topk_index[: len(new_indices)]
-        self.hidden_states = self.hidden_states[: len(new_indices)]
-        self.verified_id = self.verified_id[: len(new_indices)]
+    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
+        if has_been_filtered:
+            # In eagle_utils.py:verify, the batch has already been filtered by `unfinished_index`,
+            # so there is no need to filter it again in the scheduler.
+            if len(new_indices) != len(self.topk_p):
+                logger.warning(
+                    f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
+                )
+            self.topk_p = self.topk_p[: len(new_indices)]
+            self.topk_index = self.topk_index[: len(new_indices)]
+            self.hidden_states = self.hidden_states[: len(new_indices)]
+            self.verified_id = self.verified_id[: len(new_indices)]
+        else:
+            # In some cases (e.g., draft_extend), the batch has not been filtered by `unfinished_index` yet.
+            self.topk_p = self.topk_p[new_indices]
+            self.topk_index = self.topk_index[new_indices]
+            self.hidden_states = self.hidden_states[new_indices]
+            self.verified_id = self.verified_id[new_indices]
 
     def merge_batch(self, spec_info: EagleDraftInput):
         if self.hidden_states is None:
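Note: a minimal standalone sketch of what the two filtering modes above do to the per-request tensors. The tensors and sizes are invented for illustration; in the real code they live on an EagleDraftInput instance.

    import torch

    topk_p = torch.rand(4, 8)            # one row per running request
    new_indices = torch.tensor([0, 2])   # indices of the unfinished requests

    # has_been_filtered=True path: verify() already reordered/dropped rows,
    # so only the leading len(new_indices) rows are kept.
    kept_prefiltered = topk_p[: len(new_indices)]

    # has_been_filtered=False path (e.g. draft_extend): nothing was filtered
    # yet, so rows are selected by index instead.
    kept_unfiltered = topk_p[new_indices]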
@@ -410,8 +425,15 @@ class EagleVerifyInput:
             logits=logits_output.next_token_logits, vocab_mask=vocab_mask
         )
 
-        # Sample tokens
-        if batch.sampling_info.is_all_greedy:
+        # Sample tokens. Force greedy sampling on AMD
+        is_all_greedy = sampling_info.is_all_greedy
+        if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
+            logger.warning(
+                "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
+                "Falling back to greedy verification."
+            )
+
+        if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)
 
@@ -440,12 +462,13 @@ class EagleVerifyInput:
                     sampling_info.top_ks, self.draft_token_num, dim=0
                 ),
             ) # (bs * draft_token_num, vocab_size)
-            target_probs = top_p_renorm_prob(
-                target_probs,
-                torch.repeat_interleave(
-                    sampling_info.top_ps, self.draft_token_num, dim=0
-                ),
-            )
+            if not torch.all(sampling_info.top_ps == 1.0):
+                target_probs = top_p_renorm_prob(
+                    target_probs,
+                    torch.repeat_interleave(
+                        sampling_info.top_ps, self.draft_token_num, dim=0
+                    ),
+                )
             target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
 
             draft_probs = torch.zeros(
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -9,7 +9,6 @@ from huggingface_hub import snapshot_download
 
 from sglang.srt.distributed import (
     GroupCoordinator,
-    get_tensor_model_parallel_world_size,
     get_tp_group,
     patch_tensor_parallel_group,
 )
@@ -92,7 +91,7 @@ class EAGLEWorker(TpModelWorker):
         )
         self.padded_static_len = -1
 
-        # Override context length with target model's context length
+        # Override the context length of the draft model to be the same as the target model.
         server_args.context_length = target_worker.model_runner.model_config.context_len
 
         # Do not capture cuda graph in `super().__init__()`
@@ -267,6 +266,43 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "trtllm_mha":
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+                TRTLLMHAAttnMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
+        elif self.server_args.attention_backend == "trtllm_mla":
+            if not global_server_args_dict["use_mla_backend"]:
+                raise ValueError(
+                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+                )
+
+            from sglang.srt.layers.attention.trtllm_mla_backend import (
+                TRTLLMMLABackend,
+                TRTLLMMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMMLABackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
@@ -836,6 +872,21 @@ class EAGLEWorker(TpModelWorker):
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
+        has_finished, unfinished_req_index = False, []
+        for i, req in enumerate(batch.reqs):
+            if req.finished():
+                has_finished = True
+            else:
+                unfinished_req_index.append(i)
+        if has_finished:
+            unfinished_index_device = torch.tensor(
+                unfinished_req_index,
+                dtype=torch.int64,
+                device=batch.spec_info.topk_p.device,
+            )
+            batch.spec_info.filter_batch(
+                unfinished_index_device, has_been_filtered=False
+            )
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         assert isinstance(batch.spec_info, EagleDraftInput)
@@ -966,7 +1017,9 @@ def get_last_loc_large_page_size_top_k_1(
     return prefix_lens, seq_lens, last_loc
 
 
-@torch.compile(dynamic=True)
+# Disable torch.compile for this function because it will be
+# even slower.
+# @torch.compile(dynamic=True)
 def get_last_loc_large_page_size_large_top_k(
     req_to_token: torch.Tensor,
     req_pool_indices: torch.Tensor,
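Note: with the trtllm_mha / trtllm_mla branches added above, the EAGLE draft worker can run on the TensorRT-LLM attention backends. A hedged sketch of enabling this through the Python engine; only attention_backend, the speculative step count, and the top-k value are visible in this diff, so the remaining keyword names and model paths are assumptions:

    import sglang as sgl

    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",              # placeholder target model
        speculative_algorithm="EAGLE",                              # assumed argument name
        speculative_draft_model_path="yuhuili/EAGLE-LLaMA3.1-Instruct-8B",  # placeholder draft model
        speculative_num_steps=3,
        speculative_eagle_topk=4,                                   # assumed argument name
        speculative_num_draft_tokens=8,                             # assumed argument name
        attention_backend="trtllm_mha",                             # selects the new branch
    )
    print(llm.generate("Hello, my name is", {"max_new_tokens": 16}))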
sglang/srt/tokenizer/tiktoken_tokenizer.py ADDED
@@ -0,0 +1,161 @@
+import functools
+import json
+from typing import AbstractSet, Collection, List, Literal, Union
+
+
+class TiktokenProcessor:
+    def __init__(self, name: str):
+        self.tokenizer = TiktokenTokenizer(name)
+
+    def image_processor(self, image):
+        return {"pixel_values": [image]}
+
+
+RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
+CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]
+
+
+PAD = "<|pad|>"
+EOS = "<|eos|>"
+SEP = "<|separator|>"
+
+DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
+DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
+
+# default + separate each single digit
+PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import tiktoken
+        from jinja2 import Template
+
+        # Read the JSON
+        with open(tokenizer_path, "rb") as fin:
+            xtok_dict = json.load(fin)
+
+        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::from_xtok_dict
+        mergeable_ranks = {
+            bytes(item["bytes"]): item["token"] for item in xtok_dict["regular_tokens"]
+        }
+        special_tokens = {
+            bytes(item["bytes"]).decode(): item["token"]
+            for item in xtok_dict["special_tokens"]
+        }
+        if xtok_dict["word_split"] == "V1":
+            pad_str = PAT_STR_B
+        else:
+            assert False, f"Unknown word_split: {xtok_dict['word_split']}"
+        pad_str = xtok_dict.get("pat_str", pad_str)
+
+        kwargs = {
+            "name": tokenizer_path,
+            "pat_str": pad_str,
+            "mergeable_ranks": mergeable_ranks,
+            "special_tokens": special_tokens,
+        }
+        if "default_allowed_special" in xtok_dict:
+            default_allowed_special = set(
+                [
+                    bytes(bytes_list).decode()
+                    for bytes_list in xtok_dict["default_allowed_special"]
+                ]
+            )
+        if "vocab_size" in xtok_dict:
+            kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]
+
+        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::__init__
+        default_allowed_special = None
+        control_tokens = DEFAULT_CONTROL_TOKENS
+        tokenizer = tiktoken.Encoding(**kwargs)
+        tokenizer._default_allowed_special = default_allowed_special or set()
+        tokenizer._control_tokens = control_tokens
+
+        def encode_patched(
+            self,
+            text: str,
+            *,
+            allowed_special: Union[
+                Literal["all"], AbstractSet[str]
+            ] = set(),  # noqa: B006
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        ) -> List[int]:
+            if isinstance(allowed_special, set):
+                allowed_special |= self._default_allowed_special
+            return tiktoken.Encoding.encode(
+                self,
+                text,
+                allowed_special=allowed_special,
+                disallowed_special=(),
+            )
+
+        tokenizer.encode = functools.partial(encode_patched, tokenizer)
+
+        # Allow more tokens to prevent crash
+        tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
+        tokenizer._default_allowed_special |= set(
+            CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
+        )
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.bos_token_id = None
+        self.eos_token_id = tokenizer._special_tokens[EOS]
+        self.vocab_size = tokenizer.n_vocab
+        self.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        self.chat_template_jinja = Template(self.chat_template)
+        self.additional_stop_token_ids = None
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x, *args, **kwargs):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
+        if len(batch) > 0 and isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode_batch(batch)
+
+    def apply_chat_template(
+        self, messages, tokenize, add_generation_prompt, tools=None
+    ):
+        ret = self.chat_template_jinja.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
+        return self.encode(ret) if tokenize else ret
+
+    def __call__(self, text, **kwargs):
+        return {
+            "input_ids": self.encode(text),
+        }
+
+    def init_xgrammar(self):
+        from xgrammar import TokenizerInfo
+
+        XGRAMMAR_SPECIAL_TOKEN_TEMPLATE = "<|xg_special_token_{}|>"
+
+        enc = self.tokenizer
+        encoded_vocab = {**enc._mergeable_ranks, **enc._special_tokens}
+        encoded_vocab = [
+            token for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
+        ]
+        override_stop_tokens = [2]  # eos
+        # These are treated as special tokens in xgrammar; we want to avoid them
+        # For now, xgrammar treats anything starting with b'\x00' as a special token
+        xgrammar_special_token_ids = []
+        for i, token in enumerate(encoded_vocab):
+            if isinstance(token, bytes) and token.startswith(b"\x00"):
+                xgrammar_special_token_ids.append(i)
+
+        for i, id in enumerate(xgrammar_special_token_ids):
+            encoded_vocab[id] = XGRAMMAR_SPECIAL_TOKEN_TEMPLATE.format(i)
+        tokenizer_info = TokenizerInfo(
+            encoded_vocab, stop_token_ids=override_stop_tokens
+        )
+        assert len(tokenizer_info.special_token_ids) == 0
+
+        return tokenizer_info, override_stop_tokens
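Note: a short usage sketch for the new TiktokenTokenizer wrapper, assuming a tokenizer JSON file in the xtok format this class expects (the path below is hypothetical):

    from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer

    tok = TiktokenTokenizer("/path/to/tokenizer.json")  # hypothetical path

    ids = tok.encode("Hello world")   # plain token ids
    text = tok.decode(ids)

    # HF-style chat templating backed by the built-in Jinja template.
    prompt = tok.apply_chat_template(
        [{"role": "user", "content": "Hello"}],
        tokenize=False,
        add_generation_prompt=True,
    )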
sglang/srt/two_batch_overlap.py CHANGED
@@ -14,8 +14,13 @@ from sglang.srt.layers.communicator import (
     CommunicateSummableTensorPairFn,
     ScatterMode,
 )
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    get_tbo_token_distribution_threshold,
+    is_tbo_enabled,
+)
 from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
@@ -83,7 +88,7 @@ def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
     vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
     left_sum = sum(extend_lens[:vanilla_split_seq_index])
     overall_sum = sum(extend_lens)
-    threshold = global_server_args_dict["tbo_token_distribution_threshold"]
+    threshold = get_tbo_token_distribution_threshold()
     assert threshold <= 0.5, f"{threshold=}"
     return left_sum < overall_sum * threshold or left_sum > overall_sum * (
         1 - threshold
@@ -299,7 +304,7 @@ class TboCudaGraphRunnerPlugin:
         self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)
 
     def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
-        if not global_server_args_dict["enable_two_batch_overlap"]:
+        if not is_tbo_enabled():
             return
         token_num_per_seq = get_token_num_per_seq(
             forward_mode=batch.forward_mode, spec_info=batch.spec_info
@@ -353,10 +358,12 @@ class TboDPAttentionPreparer:
     def prepare_all_gather(
         self,
         local_batch: ScheduleBatch,
-        deepep_mode: DeepEPMode,
-        enable_deepep_moe: bool,
-        enable_two_batch_overlap: bool,
     ):
+
+        deepep_mode = get_deepep_mode()
+        enable_deepep_moe = get_moe_a2a_backend().is_deepep()
+        enable_two_batch_overlap = is_tbo_enabled()
+
         self.enable_two_batch_overlap = enable_two_batch_overlap
 
         if local_batch is not None:
@@ -384,7 +391,7 @@ class TboDPAttentionPreparer:
                     and not local_batch.forward_mode.is_target_verify()
                 )
                 and enable_deepep_moe
-                and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY)
+                and (resolved_deepep_mode.is_low_latency())
             )
         else:
             self.local_tbo_split_seq_index = 0
@@ -657,6 +664,7 @@ class TboForwardBatchPreparer:
             "req_to_token_pool",
             "token_to_kv_pool",
             "can_run_dp_cuda_graph",
+            "dp_padding_mode",
            "global_forward_mode",
            "spec_algorithm",
            "capture_hidden_mode",
@@ -701,7 +709,6 @@ class TboForwardBatchPreparer:
             tbo_children=None,
             global_num_tokens_gpu=None,
             global_num_tokens_cpu=None,
-            dp_padding_mode=None,
             global_dp_buffer_len=global_dp_buffer_len,
             global_num_tokens_for_logprob_gpu=None,
             global_num_tokens_for_logprob_cpu=None,
@@ -955,9 +962,7 @@ def _model_forward_tbo_merge_outputs(output_a, output_b):
 
 class MaybeTboDeepEPDispatcher:
     def __init__(self, **kwargs):
-        num_inner_dispatchers = (
-            2 if global_server_args_dict["enable_two_batch_overlap"] else 1
-        )
+        num_inner_dispatchers = 2 if is_tbo_enabled() else 1
         self._inners = [
             DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
         ]
sglang/srt/utils.py CHANGED
@@ -438,70 +438,6 @@ def is_pin_memory_available() -> bool:
     return torch.cuda.is_available()
 
 
-_CPU_OFFLOAD_BYTES = 0
-_CPU_OFFLOAD_MAX_BYTES = 0
-
-
-def set_cpu_offload_max_bytes(max_bytes: int) -> None:
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    _CPU_OFFLOAD_BYTES = 0
-    _CPU_OFFLOAD_MAX_BYTES = max_bytes
-
-
-def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
-    device = next(module.parameters()).device
-
-    if device == torch.device("cpu"):
-        return module
-
-    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
-    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-        return module
-
-    pin_memory = is_pin_memory_available()
-    # offload parameters to CPU
-    # use pin_memory if possible, which helps cudagraph capture speed
-    offloaded_parameters = False
-    for p in module.parameters():
-        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
-            # we use per-parameter offloading
-            # one module might have some parameters offloaded and some not
-            break
-
-        # `torch.empty_like` does not support `pin_memory` argument
-        cpu_data = torch.empty_strided(
-            size=p.data.size(),
-            stride=p.data.stride(),
-            dtype=p.data.dtype,
-            layout=p.data.layout,
-            device="cpu",
-            pin_memory=pin_memory,
-        )
-        cpu_data.copy_(p.data)
-        p.data = cpu_data
-        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
-        offloaded_parameters = True
-
-    if offloaded_parameters:
-        original_forward = module.forward
-
-        def forward(*args, **kwargs):
-            module.forward = original_forward
-            device_state = {
-                # here we blindly call `to(device)`
-                # if the parameter is already on the device, it will be a no-op
-                k: v.to(device, non_blocking=True)
-                for k, v in module.state_dict().items()
-            }
-            output = functional_call(module, device_state, args=args, kwargs=kwargs)
-            module.forward = forward
-            return output
-
-        module.forward = forward
-
-    return module
-
-
 class LayerFn(Protocol):
 
     def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
@@ -514,11 +450,13 @@ def make_layers(
     pp_size: Optional[int] = None,
     prefix: str = "",
     return_tuple: bool = False,
+    offloader_kwargs: Dict[str, Any] = {},
 ) -> Tuple[int, int, torch.nn.ModuleList]:
     """Make a list of layers with the given layer function"""
     # circula imports
     from sglang.srt.distributed import get_pp_indices
     from sglang.srt.layers.utils import PPMissingLayer
+    from sglang.srt.offloader import get_offloader
 
     assert not pp_size or num_hidden_layers >= pp_size
     start_layer, end_layer = (
@@ -532,10 +470,13 @@ def make_layers(
     )
     modules = torch.nn.ModuleList(
         [PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)]
-        + [
-            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
-            for idx in range(start_layer, end_layer)
-        ]
+        + get_offloader().wrap_modules(
+            (
+                layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
+                for idx in range(start_layer, end_layer)
+            ),
+            **offloader_kwargs,
+        )
         + [
             PPMissingLayer(return_tuple=return_tuple)
             for _ in range(end_layer, num_hidden_layers)
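Note: make_layers now routes the constructed decoder layers through the offloader abstraction (get_offloader().wrap_modules) instead of maybe_offload_to_cpu, and exposes offloader_kwargs for it. A hedged caller sketch; the leading positional parameters are assumed from the signature fragments above, and the Linear layer is just a stand-in for a real decoder layer:

    import torch

    from sglang.srt.utils import make_layers

    start_layer, end_layer, layers = make_layers(
        4,                                            # num_hidden_layers
        lambda idx, prefix: torch.nn.Linear(16, 16),  # stand-in layer_fn
        prefix="model.layers",
        offloader_kwargs={},                          # forwarded to get_offloader().wrap_modules(...)
    )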
@@ -2343,6 +2284,7 @@ def is_fa3_default_architecture(hf_config):
         "Qwen3ForCausalLM",
         "Qwen3MoeForCausalLM",
         "Glm4MoeForCausalLM",
+        "Glm4vMoeForConditionalGeneration",
         "Step3VLForConditionalGeneration",
     }
     return architectures[0] in default_archs
@@ -2413,7 +2355,7 @@ def require_mlp_tp_gather(server_args):
         return True
     elif not server_args.enable_dp_lm_head:
         return True
-    elif server_args.moe_a2a_backend is None:
+    elif server_args.moe_a2a_backend == "none":
         return True
     else:
         return (
@@ -2429,7 +2371,7 @@ def require_attn_tp_gather(server_args):
     Check if the input of attention is scattered.
     """
     assert server_args.moe_dense_tp_size in [1, None]
-    if server_args.moe_a2a_backend is not None or server_args.moe_dense_tp_size == 1:
+    if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1:
         if server_args.enable_dp_attention:
             return server_args.dp_size < server_args.tp_size
         else:
@@ -2599,6 +2541,50 @@ def dynamic_import(func_path: str):
     return func
 
 
+def gc_object_counts():
+    import gc
+
+    g0 = len(gc.get_objects(0))
+    g1 = len(gc.get_objects(1))
+    g2 = len(gc.get_objects(2))
+    return g0, g1, g2
+
+
+def configure_gc_warning(warn_threshold_secs):
+    import gc
+
+    gc_start_time = {}
+
+    def gc_callback(phase, info):
+        gen = info.get("generation", "?")
+        if phase == "start":
+            gc_start_time[gen] = time.time()
+        elif phase == "stop":
+            duration = time.time() - gc_start_time.get(gen, time.time())
+            if duration > warn_threshold_secs:
+                g0, g1, g2 = gc_object_counts()
+                logger.warn(
+                    f"LONG GARBAGE COLLECTION DETECTED | Generation {gen} | Duration: {duration:.4f}s | # Objects: gen0={g0}, gen1={g1}, gen2={g2} | "
+                    f"This may cause latency jitter. Consider calling the freeze_gc API after sending a few warmup requests."
+                )
+
+    gc.callbacks.append(gc_callback)
+
+
+def freeze_gc(context: str):
+    import gc
+
+    g0_before, g1_before, g2_before = gc_object_counts()
+    gc.freeze()
+    g0_after, g1_after, g2_after = gc_object_counts()
+    logger.info(
+        f"Freezing GC in {context} process. "
+        f"gen0: {g0_before}->{g0_after}, "
+        f"gen1: {g1_before}->{g1_after}, "
+        f"gen2: {g2_before}->{g2_after}"
+    )
+
+
 def configure_gc_logger():
     logger.info("Enable GC Logger")
 
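Note: a short sketch of how the new GC helpers are intended to be used, e.g. from server startup code (the 0.5-second threshold is an arbitrary example value):

    from sglang.srt.utils import configure_gc_warning, freeze_gc

    # Log a warning whenever a single GC pass takes longer than 0.5 s.
    configure_gc_warning(0.5)

    # After a few warmup requests, move surviving objects to the permanent
    # generation so future collections scan fewer objects.
    freeze_gc("scheduler")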
@@ -2872,6 +2858,8 @@ SUPPORTED_LORA_TARGET_MODULES = [
     "gate_proj",
     "up_proj",
     "down_proj",
+    "qkv_proj",
+    "gate_up_proj",
 ]
 
 LORA_TARGET_ALL_MODULES = "all"
@@ -2966,3 +2954,13 @@ class ConcurrentCounter:
 @lru_cache(maxsize=1)
 def is_triton_kernels_available() -> bool:
     return importlib.util.find_spec("triton_kernels") is not None
+
+
+def check_cuda_result(raw_output):
+    import cuda.bindings.runtime as cuda_rt
+
+    err, *results = raw_output
+    if err != cuda_rt.cudaError_t.cudaSuccess:
+        raise Exception(f"CUDA error: {err}")
+
+    return results
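Note: check_cuda_result unpacks the (error_code, *values) tuples returned by the cuda-python runtime bindings and raises on failure. A hedged usage sketch; cudaMemGetInfo is assumed to be available under cuda.bindings.runtime, which requires the cuda-python package:

    import cuda.bindings.runtime as cuda_rt

    from sglang.srt.utils import check_cuda_result

    # cudaMemGetInfo returns (err, free_bytes, total_bytes); check_cuda_result
    # validates err and returns the remaining values as a list.
    free_bytes, total_bytes = check_cuda_result(cuda_rt.cudaMemGetInfo())
    print(f"free={free_bytes}, total={total_bytes}")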
sglang/test/runners.py CHANGED
@@ -231,11 +231,14 @@ class HFRunner:
 
         # Load the model and tokenizer
         if self.model_type == "generation":
-            config = AutoConfig.from_pretrained(model_path)
-            if model_archs := getattr(config, "architectures"):
-                model_cls = getattr(transformers, model_archs[0])
-            else:
+            config = AutoConfig.from_pretrained(
+                model_path, trust_remote_code=self.trust_remote_code
+            )
+            if self.trust_remote_code:
                 model_cls = AutoModelForCausalLM
+            else:
+                model_arch = getattr(config, "architectures")[0]
+                model_cls = getattr(transformers, model_arch)
             self.base_model = model_cls.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
@@ -488,7 +491,7 @@ class SRTRunner:
         tp_size: int = 1,
         model_impl: str = "auto",
         port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
-        lora_paths: List[str] = None,
+        lora_paths: Optional[Union[List[str], List[dict[str, str]]]] = None,
         max_loras_per_batch: int = 4,
         attention_backend: Optional[str] = None,
         prefill_attention_backend: Optional[str] = None,
sglang/test/test_block_fp8.py CHANGED
@@ -6,7 +6,7 @@ import torch
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.topk import TopKConfig, select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,
@@ -498,11 +498,13 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
         score = torch.randn((M, E), dtype=dtype)
 
         with torch.inference_mode():
+            ref_out = torch_w8a8_block_fp8_moe(
+                a, w1, w2, w1_s, w2_s, score, topk, block_size
+            )
             topk_output = select_experts(
                 hidden_states=a,
                 router_logits=score,
-                top_k=topk,
-                renormalize=False,
+                topk_config=TopKConfig(top_k=topk, renormalize=False),
             )
             out = fused_moe(
                 a,
@@ -514,9 +516,6 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
                 w2_scale=w2_s,
                 block_shape=block_size,
             )
-            ref_out = torch_w8a8_block_fp8_moe(
-                a, w1, w2, w1_s, w2_s, score, topk, block_size
-            )
 
         self.assertTrue(
             torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
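Note: select_experts now takes a TopKConfig object instead of separate top_k/renormalize keyword arguments, as the test change above shows. A minimal sketch of the migrated call; tensor shapes are illustrative only:

    import torch

    from sglang.srt.layers.moe.topk import TopKConfig, select_experts

    hidden_states = torch.randn(8, 1024)   # (num_tokens, hidden_size)
    router_logits = torch.randn(8, 64)     # (num_tokens, num_experts)

    topk_output = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        topk_config=TopKConfig(top_k=2, renormalize=False),
    )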