sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py
@@ -38,7 +38,7 @@ import threading
  from enum import Enum, auto
  from http import HTTPStatus
  from itertools import chain
- from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

  import numpy as np
  import torch
@@ -52,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
  ScheduleBatchDisaggregationDecodeMixin,
  )
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
- from sglang.srt.layers.moe import is_tbo_enabled
  from sglang.srt.mem_cache.allocator import (
  BaseTokenToKVPoolAllocator,
  SWATokenToKVPoolAllocator,
@@ -60,7 +59,7 @@ from sglang.srt.mem_cache.allocator import (
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
  from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
- from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
  from sglang.srt.metrics.collector import TimeStats
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -99,6 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
  "sampling_backend",
  "speculative_accept_threshold_single",
  "speculative_accept_threshold_acc",
+ "speculative_attention_mode",
  "torchao_config",
  "triton_attention_reduce_in_fp32",
  "num_reserved_decode_tokens",
@@ -911,7 +911,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  is_prefill_only: bool = False

  # hicache pointer for synchronizing data loading from CPU to GPU
- hicache_consumer_index: int = 0
+ hicache_consumer_index: int = -1

  @classmethod
  def init_new(
@@ -962,8 +962,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  def is_empty(self):
  return len(self.reqs) == 0

- def alloc_req_slots(self, num_reqs: int):
- req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
+ def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
+ if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
+ req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
+ else:
+ req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
  if req_pool_indices is None:
  raise RuntimeError(
  "alloc_req_slots runs out of memory. "
@@ -1138,7 +1141,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

  # Allocate req slots
  bs = len(self.reqs)
- req_pool_indices = self.alloc_req_slots(bs)
+ req_pool_indices = self.alloc_req_slots(bs, self.reqs)

  # Init tensors
  reqs = self.reqs
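The two hunks above thread the Req objects through to the pool so that a hybrid pool (regular KV cache plus Mamba/linear-attention state, e.g. for the hybrid models added in this release) can set up per-request state at allocation time. A minimal sketch of the dispatch idea with simplified stand-in classes; these are illustrations, not the actual sglang pool implementations:

from typing import List, Optional


class ReqToTokenPoolSketch:
    """Plain pool: only needs the request count."""

    def __init__(self, size: int):
        self.free_slots = list(range(size))

    def alloc(self, num_reqs: int) -> Optional[List[int]]:
        if num_reqs > len(self.free_slots):
            return None
        out, self.free_slots = self.free_slots[:num_reqs], self.free_slots[num_reqs:]
        return out


class HybridReqToTokenPoolSketch(ReqToTokenPoolSketch):
    """Hybrid pool: also tracks per-request recurrent state, so it wants the Req objects."""

    def __init__(self, size: int):
        super().__init__(size)
        self.state_of = {}  # slot index -> per-request state handle (stand-in)

    def alloc(self, num_reqs: int, reqs: Optional[list] = None) -> Optional[List[int]]:
        indices = super().alloc(num_reqs)
        if indices is not None and reqs is not None:
            for idx, req in zip(indices, reqs):
                self.state_of[idx] = req  # e.g. reserve a Mamba state slot for this request
        return indices


def alloc_req_slots(pool, num_reqs: int, reqs=None):
    # Mirrors the isinstance dispatch in the diff above.
    if isinstance(pool, HybridReqToTokenPoolSketch):
        return pool.alloc(num_reqs, reqs)
    return pool.alloc(num_reqs)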
@@ -1372,21 +1375,28 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  # TODO (lianmin): Revisit this. It should be seq_len - 1
  self.extend_logprob_start_lens.extend([0] * running_bs)

- def new_page_count_next_decode(self):
+ def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
  page_size = self.token_to_kv_pool_allocator.page_size
+ requests = (
+ self.reqs
+ if selected_indices is None
+ else [self.reqs[i] for i in selected_indices]
+ )
  if page_size == 1:
- return len(self.reqs)
+ return len(requests)
  # In the decoding phase, the length of a request's KV cache should be
  # the total length of the request minus 1
  return (
- sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+ sum(1 for req in requests if req.seqlen % page_size == 0)
  if self.enable_overlap
- else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+ else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
  )

- def check_decode_mem(self, buf_multiplier=1):
+ def check_decode_mem(
+ self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
+ ):
  num_tokens = (
- self.new_page_count_next_decode()
+ self.new_page_count_next_decode(selected_indices)
  * buf_multiplier
  * self.token_to_kv_pool_allocator.page_size
  )
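new_page_count_next_decode and check_decode_mem can now be evaluated for a subset of requests (selected_indices), which the retract loop below uses. To make the page accounting concrete: with paged allocation, a request only needs a fresh KV-cache page for the next decode step when its cached tokens exactly fill its last page. A small standalone sketch with invented values:

def new_page_count_next_decode(seqlens, page_size, enable_overlap=False):
    # Mirrors the counting above: with page_size == 1 every request needs one
    # new token slot per decode step; otherwise only requests whose cached
    # tokens (seqlen - 1 in the non-overlap case) fill their last page exactly.
    if page_size == 1:
        return len(seqlens)
    if enable_overlap:
        return sum(1 for s in seqlens if s % page_size == 0)
    return sum(1 for s in seqlens if (s - 1) % page_size == 0)


# Example: page_size=16, three running requests with invented lengths.
# seqlen 33 -> the 32 tokens already cached fill exactly two 16-token pages,
# so the token decoded this step opens a new page; 40 and 18 still have room.
print(new_page_count_next_decode([33, 40, 18], page_size=16))  # -> 1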
@@ -1412,34 +1422,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  reverse=True,
  )

- def get_required_tokens(num_reqs: int):
- headroom_for_spec_decode = 0
- if server_args.speculative_algorithm:
- headroom_for_spec_decode += (
- num_reqs
- * server_args.speculative_eagle_topk
- * server_args.speculative_num_steps
- + num_reqs * server_args.speculative_num_draft_tokens
- )
- return (
- num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
- )
-
- def _get_available_size():
- if self.is_hybrid:
- return min(
- self.token_to_kv_pool_allocator.full_available_size(),
- self.token_to_kv_pool_allocator.swa_available_size(),
- )
- else:
- return self.token_to_kv_pool_allocator.available_size()
-
  retracted_reqs = []
  seq_lens_cpu = self.seq_lens.cpu().numpy()
  first_iter = True
- while (
- _get_available_size() < get_required_tokens(len(sorted_indices))
- or first_iter
+ while first_iter or (
+ not self.check_decode_mem(selected_indices=sorted_indices)
  ):
  if len(sorted_indices) == 1:
  # Corner case: only one request left
@@ -1493,10 +1480,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  else:
  self.tree_cache.dec_lock_ref(req.last_node)

- # NOTE(lsyin): we should use the newly evictable memory instantly.
- num_tokens = len(sorted_indices) * global_config.retract_decode_steps
- self._evict_tree_cache_if_needed(num_tokens)
-
  req.reset_for_retract()

  if len(retracted_reqs) == 0:
@@ -1540,7 +1523,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  self.forward_mode = ForwardMode.DECODE
  bs = len(self.reqs)

- if self.spec_algorithm.is_eagle():
+ if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
  # if spec decoding is used, the decode batch is prepared inside
  # `forward_batch_speculative_generation` after running draft models.
  return
@@ -1917,7 +1900,7 @@ class ModelWorkerBatch:
  spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
  # If set, the output of the batch contains the hidden states of the run.
  capture_hidden_mode: CaptureHiddenMode = None
- hicache_consumer_index: int = 0
+ hicache_consumer_index: int = -1

  # Overlap event
  launch_done: Optional[threading.Event] = None
sglang/srt/managers/schedule_policy.py
@@ -380,8 +380,9 @@ class PrefillAdder:
  self.log_input_tokens += extend_input_len

  def add_chunked_req(self, req: Req):
- truncated = req.extend_input_len > self.rem_chunk_tokens
- req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+ _rem_tokens = min(self.rem_chunk_tokens, int(self.rem_total_tokens))
+ truncated = req.extend_input_len > _rem_tokens
+ req.extend_input_len = min(req.extend_input_len, _rem_tokens)
  req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
  self.can_run_list.append(req)
  self._update_prefill_budget(
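add_chunked_req now clamps a chunk to both the remaining chunk budget and the remaining total-token budget, so a chunked prefill can no longer be sized past what the KV pool can still hold. A toy illustration with invented numbers:

# Hypothetical budgets, not taken from the diff.
rem_chunk_tokens = 8192    # chunked-prefill budget left in this batch
rem_total_tokens = 3000.0  # tokens the KV pool can still accommodate
extend_input_len = 5000    # tokens this request still wants to prefill

# Old behavior: only the chunk budget was considered.
old_len = min(extend_input_len, rem_chunk_tokens)         # 5000 (may not fit)

# New behavior: also respect the total-token budget.
_rem_tokens = min(rem_chunk_tokens, int(rem_total_tokens))
new_len = min(extend_input_len, _rem_tokens)               # 3000
truncated = extend_input_len > _rem_tokens                 # True
print(old_len, new_len, truncated)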
@@ -549,7 +550,7 @@
  )
  else:
  # Make sure at least one page is available
- trunc_len = self.rem_chunk_tokens - self.page_size + 1
+ trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size
  if trunc_len <= 0:
  return AddReqResult.OTHER

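The truncation length is now rounded down to a whole number of pages rather than leaving page_size - 1 tokens of slack, keeping the chunk boundary page-aligned. For example, with invented numbers:

rem_chunk_tokens = 100
page_size = 16

old_trunc_len = rem_chunk_tokens - page_size + 1           # 85, not page-aligned
new_trunc_len = rem_chunk_tokens // page_size * page_size   # 96, exactly 6 pages
print(old_trunc_len, new_trunc_len)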
sglang/srt/managers/scheduler.py
@@ -69,6 +69,8 @@ from sglang.srt.managers.io_struct import (
  AbortReq,
  BatchTokenizedEmbeddingReqInput,
  BatchTokenizedGenerateReqInput,
+ ClearHiCacheReqInput,
+ ClearHiCacheReqOutput,
  CloseSessionReqInput,
  ExpertDistributionReq,
  ExpertDistributionReqOutput,
@@ -82,6 +84,8 @@ from sglang.srt.managers.io_struct import (
  InitWeightsUpdateGroupReqInput,
  LoadLoRAAdapterReqInput,
  LoadLoRAAdapterReqOutput,
+ MultiTokenizerRegisterReq,
+ MultiTokenizerWrapper,
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
@@ -137,7 +141,7 @@ from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
- from sglang.srt.reasoning_parser import ReasoningParser
+ from sglang.srt.parser.reasoning_parser import ReasoningParser
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -154,6 +158,7 @@ from sglang.srt.utils import (
  get_zmq_socket,
  is_cpu,
  kill_itself_when_parent_died,
+ numa_bind_to_node,
  point_to_point_pyobj,
  pyspy_dump_schedulers,
  require_mlp_sync,
@@ -255,7 +260,6 @@ class Scheduler(
  # Init inter-process communication
  context = zmq.Context(2)
  self.idle_sleeper = None
-
  if self.pp_rank == 0 and self.attn_tp_rank == 0:
  self.recv_from_tokenizer = get_zmq_socket(
  context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -345,6 +349,18 @@ class Scheduler(
  target_worker=self.tp_worker,
  dp_rank=dp_rank,
  )
+ elif self.spec_algorithm.is_standalone():
+ from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+ self.draft_worker = StandaloneWorker(
+ gpu_id=gpu_id,
+ tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
+ server_args=server_args,
+ nccl_port=port_args.nccl_port,
+ target_worker=self.tp_worker,
+ dp_rank=dp_rank,
+ )
  else:
  self.draft_worker = None

@@ -398,7 +414,7 @@ class Scheduler(
  f"max_prefill_tokens={self.max_prefill_tokens}, "
  f"max_running_requests={self.max_running_requests}, "
  f"context_len={self.model_config.context_len}, "
- f"available_gpu_mem={avail_mem:.2f} GB"
+ f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
  )

  # Init memory pool and cache
@@ -485,7 +501,7 @@ class Scheduler(
  enable=server_args.enable_memory_saver
  )
  self.offload_tags = set()
- self.init_profier()
+ self.init_profiler()

  self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
  self.input_blocker = (
@@ -497,6 +513,7 @@ class Scheduler(
  # Init metrics stats
  self.init_metrics(tp_rank, pp_rank, dp_rank)
  self.init_kv_events(server_args.kv_events_config)
+ self.init_dp_balance(dp_balance_meta)

  # Init disaggregation
  self.disaggregation_mode = DisaggregationMode(
@@ -515,6 +532,7 @@ class Scheduler(
  (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
  (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
  (FlushCacheReqInput, self.flush_cache_wrapped),
+ (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
  (AbortReq, self.abort_request),
  (OpenSessionReqInput, self.open_session),
  (CloseSessionReqInput, self.close_session),
@@ -537,18 +555,10 @@ class Scheduler(
  (ExpertDistributionReq, self.expert_distribution_handle),
  (LoadLoRAAdapterReqInput, self.load_lora_adapter),
  (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
+ (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
  ]
  )

- self.balance_meta = dp_balance_meta
- if (
- server_args.enable_dp_attention
- and server_args.load_balance_method == "minimum_tokens"
- ):
- assert dp_balance_meta is not None
-
- self.recv_dp_balance_id_this_term = []
-
  def init_tokenizer(self):
  server_args = self.server_args
  self.is_generation = self.model_config.is_generation
@@ -625,6 +635,7 @@ class Scheduler(
  hicache_write_policy=server_args.hicache_write_policy,
  hicache_io_backend=server_args.hicache_io_backend,
  hicache_mem_layout=server_args.hicache_mem_layout,
+ enable_metrics=self.enable_metrics,
  hicache_storage_backend=server_args.hicache_storage_backend,
  hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
  model_name=server_args.served_model_name,
@@ -657,6 +668,21 @@ class Scheduler(
  page_size=self.page_size,
  disable=server_args.disable_radix_cache,
  )
+ elif server_args.enable_lmcache:
+ from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
+ LMCRadixCache,
+ )
+
+ self.tree_cache = LMCRadixCache(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ page_size=self.page_size,
+ disable=server_args.disable_radix_cache,
+ model_config=self.model_config,
+ tp_size=self.tp_size,
+ rank=self.tp_rank,
+ tp_group=self.tp_group,
+ )
  else:
  self.tree_cache = RadixCache(
  req_to_token_pool=self.req_to_token_pool,
@@ -1098,6 +1124,17 @@ class Scheduler(
  )
  self.send_to_tokenizer.send_pyobj(abort_req)
  continue
+
+ # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
+ if isinstance(recv_req, MultiTokenizerWrapper):
+ worker_id = recv_req.worker_id
+ recv_req = recv_req.obj
+ output = self._request_dispatcher(recv_req)
+ if output is not None:
+ output = MultiTokenizerWrapper(worker_id, output)
+ self.send_to_tokenizer.send_pyobj(output)
+ continue
+
  output = self._request_dispatcher(recv_req)
  if output is not None:
  if isinstance(output, RpcReqOutput):
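This hunk is part of the new multi-tokenizer support (see multi_tokenizer_mixin.py in the file list): control requests from the extra tokenizer workers arrive wrapped with the sender's worker_id, and the scheduler unwraps the payload, dispatches it as usual, and re-wraps the reply so it can be routed back. A minimal sketch of that round trip with a simplified stand-in type; the real MultiTokenizerWrapper is defined in sglang/srt/managers/io_struct.py and may carry more fields:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class MultiTokenizerWrapperSketch:
    worker_id: int
    obj: Any  # the wrapped request or response


def dispatch(recv_req: Any, handler) -> Optional[Any]:
    """Mirrors the unwrap/re-wrap pattern in the event loop above."""
    if isinstance(recv_req, MultiTokenizerWrapperSketch):
        worker_id = recv_req.worker_id
        output = handler(recv_req.obj)
        # Re-wrap so the reply can be routed back to the worker that sent it.
        return (
            MultiTokenizerWrapperSketch(worker_id, output)
            if output is not None
            else None
        )
    return handler(recv_req)


# Usage: a request from tokenizer worker 3 comes back tagged for worker 3.
reply = dispatch(MultiTokenizerWrapperSketch(3, "flush_cache"), lambda req: f"ok:{req}")
print(reply)  # MultiTokenizerWrapperSketch(worker_id=3, obj='ok:flush_cache')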
@@ -1110,11 +1147,7 @@ class Scheduler(
  self,
  recv_req: TokenizedGenerateReqInput,
  ):
- if (
- self.server_args.enable_dp_attention
- and self.server_args.load_balance_method == "minimum_tokens"
- ):
- self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+ self.maybe_update_dp_balance_data(recv_req)

  # Create a new request
  if (
@@ -1405,9 +1438,11 @@ class Scheduler(
  _, _, available_size, evictable_size = self._get_token_info()
  protected_size = self.tree_cache.protected_size()
  memory_leak = (available_size + evictable_size) != (
+ # self.max_total_num_tokens
+ # if not self.enable_hierarchical_cache
+ # else self.max_total_num_tokens - protected_size
  self.max_total_num_tokens
- if not self.enable_hierarchical_cache
- else self.max_total_num_tokens - protected_size
+ - protected_size
  )
  token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

@@ -1503,9 +1538,14 @@ class Scheduler(
  # Move the chunked request out of the batch so that we can merge
  # only finished requests to running_batch.
  chunked_req_to_exclude.add(self.chunked_req)
- self.tree_cache.cache_unfinished_req(self.chunked_req)
+ self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
  # chunked request keeps its rid but will get a new req_pool_idx
- self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
+ if self.tp_worker.worker.model_runner.is_hybrid_gdn:
+ self.req_to_token_pool.free(
+ self.chunked_req.req_pool_idx, free_mamba_cache=False
+ )
+ else:
+ self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
  if self.last_batch and self.last_batch.forward_mode.is_extend():
  if self.last_batch.chunked_req is not None:
  # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
@@ -1552,11 +1592,7 @@ class Scheduler(

  # Handle DP attention
  if need_dp_attn_preparation:
- if (
- self.server_args.load_balance_method == "minimum_tokens"
- and self.forward_ct % 40 == 0
- ):
- self.handle_dp_balance_data(ret)
+ self.maybe_handle_dp_balance_data()
  ret = self.prepare_mlp_sync_batch(ret)

  return ret
@@ -1776,10 +1812,6 @@ class Scheduler(
  if self.spec_algorithm.is_none():
  model_worker_batch = batch.get_model_worker_batch()

- # update the consumer index of hicache to the running batch
- self.tp_worker.set_hicache_consumer(
- model_worker_batch.hicache_consumer_index
- )
  if self.pp_group.is_last_rank:
  logits_output, next_token_ids, can_run_cuda_graph = (
  self.tp_worker.forward_batch_generation(model_worker_batch)
@@ -1881,86 +1913,6 @@ class Scheduler(
  disable_overlap_schedule=self.server_args.disable_overlap_schedule,
  )

- def handle_dp_balance_data(self, local_batch: ScheduleBatch):
- def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
- """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
- recv_list = self.recv_dp_balance_id_this_term
- assert len(recv_list) <= 511, (
- "The number of requests received this round is too large. "
- "Please increase gather_tensor_size and onfly_info_size."
- )
- # The maximum size of the tensor used for gathering data from all workers.
- gather_tensor_size = 512
-
- # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
- recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
- recv_tensor[0] = holding_tokens_list
- recv_tensor[1] = len(
- recv_list
- ) # The first element is the length of the list.
- recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
- recv_list, dtype=torch.int32
- )
-
- if self.tp_rank == 0:
- gathered_list = [
- torch.zeros(gather_tensor_size, dtype=torch.int32)
- for _ in range(self.balance_meta.num_workers)
- ]
- else:
- gathered_list = None
-
- torch.distributed.gather(
- recv_tensor, gathered_list, group=self.tp_cpu_group
- )
-
- gathered_id_list_per_worker = None
- if self.tp_rank == 0:
- gathered_id_list_per_worker = []
- holding_tokens_list = []
- for tensor in gathered_list:
- holding_tokens_list.append(tensor[0].item())
- list_length = tensor[1].item()
- gathered_id_list_per_worker.append(
- tensor[2 : list_length + 2].tolist()
- )
-
- return gathered_id_list_per_worker, holding_tokens_list
-
- def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
- meta = self.balance_meta
-
- with meta.mutex:
- onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
- assert len(new_recv_rid_lists) == len(
- onfly_list
- ), "num_worker not equal"
- # 1.Check if the rid received by each worker this round is present in onfly.
- # If it is, remove the corresponding onfly item.
- worker_id = 0
- for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
- for new_recv_rid in new_recv_rids:
- assert (
- new_recv_rid in on_fly_reqs
- ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
- del on_fly_reqs[new_recv_rid]
- worker_id += 1
- # 2. Atomically write local_tokens and onfly into shm under the mutex
- meta.set_shared_onfly_info(onfly_list)
- meta.set_shared_local_tokens(local_tokens)
-
- holding_tokens = self.get_load()
-
- new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
- holding_tokens
- )
-
- self.recv_dp_balance_id_this_term.clear()
- if self.tp_rank == 0: # only first worker write info
- write_shared_dp_balance_info(
- new_recv_dp_balance_id_list, holding_token_list
- )
-
  @staticmethod
  def prepare_mlp_sync_batch_raw(
  local_batch: ScheduleBatch,
@@ -2207,6 +2159,16 @@ class Scheduler(
  success = self.flush_cache()
  return FlushCacheReqOutput(success=success)

+ def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput):
+ if self.enable_hierarchical_cache:
+ self.tree_cache.clear_storage_backend()
+ logger.info("Hierarchical cache cleared successfully!")
+ if_success = True
+ else:
+ logging.warning("Hierarchical cache is not enabled.")
+ if_success = False
+ return ClearHiCacheReqOutput(success=if_success)
+
  def flush_cache(self):
  """Flush the memory pool and cache."""
  if (
@@ -2291,10 +2253,9 @@ class Scheduler(
  "token_capacity": int(self.max_total_num_tokens),
  }

- if not _is_cpu:
- ret["memory_usage"]["cuda_graph"] = round(
- self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
- )
+ ret["memory_usage"]["graph"] = round(
+ self.tp_worker.worker.model_runner.graph_mem_usage, 2
+ )

  if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
  ret["avg_spec_accept_length"] = (
@@ -2377,7 +2338,14 @@ class Scheduler(
  # This only works for requests that have not started anything.
  # We still need to send something back to TokenizerManager to clean up the state.
  req = self.waiting_queue.pop(i)
+ if self.enable_hicache_storage:
+ # to release prefetch events associated with the request
+ self.tree_cache.release_aborted_request(req.rid)
  self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+ # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
+ if self.disaggregation_mode == DisaggregationMode.DECODE:
+ self.tree_cache.cache_finished_req(req)
+
  logger.debug(f"Abort queued request. {req.rid=}")

  # Delete the requests in the grammar queue
@@ -2457,6 +2425,10 @@ class Scheduler(
  result = self.tp_worker.unload_lora_adapter(recv_req)
  return result

+ def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
+ self.send_to_detokenizer.send_pyobj(recv_req)
+ return recv_req
+
  def slow_down(self, recv_req: SlowDownReqInput):
  t = recv_req.forward_sleep_time
  if t is not None and t <= 0:
@@ -2578,6 +2550,9 @@ def run_scheduler_process(
  pipe_writer,
  balance_meta: Optional[DPBalanceMeta] = None,
  ):
+ if (numa_node := server_args.numa_node) is not None:
+ numa_bind_to_node(numa_node[gpu_id])
+
  # Generate the prefix
  prefix = ""
  if dp_rank is not None:
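run_scheduler_process now binds each scheduler process to the NUMA node configured for its GPU before anything else is initialized. The sketch below shows one plausible way a numa_bind_to_node helper can be written on Linux with libnuma through ctypes; it is an assumption for illustration, not the actual implementation in sglang.srt.utils:

import ctypes


def numa_bind_to_node_sketch(node: int) -> None:
    """Illustrative only: pin the calling process's CPUs and preferred memory to one NUMA node."""
    libnuma = ctypes.CDLL("libnuma.so.1")
    if libnuma.numa_available() < 0:
        raise RuntimeError("NUMA is not available on this system")
    # Run only on CPUs of `node` and prefer allocating memory there.
    libnuma.numa_run_on_node(ctypes.c_int(node))
    libnuma.numa_set_preferred(ctypes.c_int(node))


# Hypothetical usage mirroring the diff: one NUMA node per GPU id.
numa_node = [0, 0, 1, 1]  # e.g. GPUs 0-1 on node 0, GPUs 2-3 on node 1
gpu_id = 2
numa_bind_to_node_sketch(numa_node[gpu_id])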