sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -38,7 +38,7 @@ import threading
  from enum import Enum, auto
  from http import HTTPStatus
  from itertools import chain
- from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
 
  import numpy as np
  import torch
@@ -52,7 +52,6 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
  ScheduleBatchDisaggregationDecodeMixin,
  )
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
- from sglang.srt.layers.moe import is_tbo_enabled
  from sglang.srt.mem_cache.allocator import (
  BaseTokenToKVPoolAllocator,
  SWATokenToKVPoolAllocator,
@@ -60,7 +59,7 @@ from sglang.srt.mem_cache.allocator import (
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
  from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
- from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
  from sglang.srt.metrics.collector import TimeStats
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -99,6 +98,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
  "sampling_backend",
  "speculative_accept_threshold_single",
  "speculative_accept_threshold_acc",
+ "speculative_attention_mode",
  "torchao_config",
  "triton_attention_reduce_in_fp32",
  "num_reserved_decode_tokens",
@@ -911,7 +911,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  is_prefill_only: bool = False
 
  # hicache pointer for synchronizing data loading from CPU to GPU
- hicache_consumer_index: int = 0
+ hicache_consumer_index: int = -1
 
  @classmethod
  def init_new(
@@ -962,8 +962,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  def is_empty(self):
  return len(self.reqs) == 0
 
- def alloc_req_slots(self, num_reqs: int):
- req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
+ def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
+ if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
+ req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
+ else:
+ req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
  if req_pool_indices is None:
  raise RuntimeError(
  "alloc_req_slots runs out of memory. "
@@ -1138,7 +1141,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
  # Allocate req slots
  bs = len(self.reqs)
- req_pool_indices = self.alloc_req_slots(bs)
+ req_pool_indices = self.alloc_req_slots(bs, self.reqs)
 
  # Init tensors
  reqs = self.reqs
@@ -1372,21 +1375,28 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  # TODO (lianmin): Revisit this. It should be seq_len - 1
  self.extend_logprob_start_lens.extend([0] * running_bs)
 
- def new_page_count_next_decode(self):
+ def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
  page_size = self.token_to_kv_pool_allocator.page_size
+ requests = (
+ self.reqs
+ if selected_indices is None
+ else [self.reqs[i] for i in selected_indices]
+ )
  if page_size == 1:
- return len(self.reqs)
+ return len(requests)
  # In the decoding phase, the length of a request's KV cache should be
  # the total length of the request minus 1
  return (
- sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+ sum(1 for req in requests if req.seqlen % page_size == 0)
  if self.enable_overlap
- else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+ else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
  )
 
- def check_decode_mem(self, buf_multiplier=1):
+ def check_decode_mem(
+ self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
+ ):
  num_tokens = (
- self.new_page_count_next_decode()
+ self.new_page_count_next_decode(selected_indices)
  * buf_multiplier
  * self.token_to_kv_pool_allocator.page_size
  )
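
The reworked `new_page_count_next_decode` above counts how many requests will need a fresh KV-cache page on the next decode step, now optionally restricted to a subset of request indices so `check_decode_mem` can be reused during retraction (see the next hunk). The arithmetic it encodes: with page size 1 every request needs a new slot; otherwise a request crosses a page boundary when its cached length is an exact multiple of the page size. A small standalone illustration with made-up values, not sglang code:

    def new_pages_needed(seq_lens, page_size, overlap=False, selected=None):
        # Count requests whose KV cache spills into a new page on the next
        # decode step; `selected` mirrors the diff's selected_indices filter.
        if selected is not None:
            seq_lens = [seq_lens[i] for i in selected]
        if page_size == 1:
            return len(seq_lens)
        if overlap:
            return sum(1 for n in seq_lens if n % page_size == 0)
        return sum(1 for n in seq_lens if (n - 1) % page_size == 0)


    # page_size=16: a request with seqlen 33 already holds 32 cached tokens
    # (two full pages), so its next token opens a third page; seqlen 50 holds
    # 49 cached tokens and still has room in its fourth page.
    print(new_pages_needed([33, 50], page_size=16))                # -> 1
    print(new_pages_needed([33, 50], page_size=16, selected=[1]))  # -> 0
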
@@ -1412,34 +1422,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  reverse=True,
  )
 
- def get_required_tokens(num_reqs: int):
- headroom_for_spec_decode = 0
- if server_args.speculative_algorithm:
- headroom_for_spec_decode += (
- num_reqs
- * server_args.speculative_eagle_topk
- * server_args.speculative_num_steps
- + num_reqs * server_args.speculative_num_draft_tokens
- )
- return (
- num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
- )
-
- def _get_available_size():
- if self.is_hybrid:
- return min(
- self.token_to_kv_pool_allocator.full_available_size(),
- self.token_to_kv_pool_allocator.swa_available_size(),
- )
- else:
- return self.token_to_kv_pool_allocator.available_size()
-
  retracted_reqs = []
  seq_lens_cpu = self.seq_lens.cpu().numpy()
  first_iter = True
- while (
- _get_available_size() < get_required_tokens(len(sorted_indices))
- or first_iter
+ while first_iter or (
+ not self.check_decode_mem(selected_indices=sorted_indices)
  ):
  if len(sorted_indices) == 1:
  # Corner case: only one request left
@@ -1493,10 +1480,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  else:
  self.tree_cache.dec_lock_ref(req.last_node)
 
- # NOTE(lsyin): we should use the newly evictable memory instantly.
- num_tokens = len(sorted_indices) * global_config.retract_decode_steps
- self._evict_tree_cache_if_needed(num_tokens)
-
  req.reset_for_retract()
 
  if len(retracted_reqs) == 0:
@@ -1540,7 +1523,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  self.forward_mode = ForwardMode.DECODE
  bs = len(self.reqs)
 
- if self.spec_algorithm.is_eagle():
+ if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
  # if spec decoding is used, the decode batch is prepared inside
  # `forward_batch_speculative_generation` after running draft models.
  return
@@ -1917,7 +1900,7 @@ class ModelWorkerBatch:
  spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
  # If set, the output of the batch contains the hidden states of the run.
  capture_hidden_mode: CaptureHiddenMode = None
- hicache_consumer_index: int = 0
+ hicache_consumer_index: int = -1
 
  # Overlap event
  launch_done: Optional[threading.Event] = None
@@ -380,8 +380,9 @@ class PrefillAdder:
  self.log_input_tokens += extend_input_len
 
  def add_chunked_req(self, req: Req):
- truncated = req.extend_input_len > self.rem_chunk_tokens
- req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+ _rem_tokens = min(self.rem_chunk_tokens, int(self.rem_total_tokens))
+ truncated = req.extend_input_len > _rem_tokens
+ req.extend_input_len = min(req.extend_input_len, _rem_tokens)
  req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
  self.can_run_list.append(req)
  self._update_prefill_budget(
@@ -549,7 +550,7 @@ class PrefillAdder:
  )
  else:
  # Make sure at least one page is available
- trunc_len = self.rem_chunk_tokens - self.page_size + 1
+ trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size
  if trunc_len <= 0:
  return AddReqResult.OTHER
 
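
The `trunc_len` change above replaces the old "leave page_size - 1 tokens of headroom" formula with rounding the remaining chunk budget down to a whole number of pages. A quick numeric illustration (the values are made up):

    page_size = 64
    rem_chunk_tokens = 250

    old_trunc_len = rem_chunk_tokens - page_size + 1           # 187, not page-aligned
    new_trunc_len = rem_chunk_tokens // page_size * page_size  # 192, exactly 3 pages

    assert new_trunc_len % page_size == 0
    print(old_trunc_len, new_trunc_len)
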
@@ -67,6 +67,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.moe import initialize_moe_config
  from sglang.srt.managers.io_struct import (
  AbortReq,
+ BatchTokenizedEmbeddingReqInput,
+ BatchTokenizedGenerateReqInput,
+ ClearHiCacheReqInput,
+ ClearHiCacheReqOutput,
  CloseSessionReqInput,
  ExpertDistributionReq,
  ExpertDistributionReqOutput,
@@ -80,6 +84,8 @@ from sglang.srt.managers.io_struct import (
  InitWeightsUpdateGroupReqInput,
  LoadLoRAAdapterReqInput,
  LoadLoRAAdapterReqOutput,
+ MultiTokenizerRegisterReq,
+ MultiTokenizerWrapper,
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
@@ -135,7 +141,7 @@ from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
- from sglang.srt.reasoning_parser import ReasoningParser
+ from sglang.srt.parser.reasoning_parser import ReasoningParser
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -152,6 +158,7 @@ from sglang.srt.utils import (
  get_zmq_socket,
  is_cpu,
  kill_itself_when_parent_died,
+ numa_bind_to_node,
  point_to_point_pyobj,
  pyspy_dump_schedulers,
  require_mlp_sync,
@@ -253,7 +260,6 @@ class Scheduler(
  # Init inter-process communication
  context = zmq.Context(2)
  self.idle_sleeper = None
-
  if self.pp_rank == 0 and self.attn_tp_rank == 0:
  self.recv_from_tokenizer = get_zmq_socket(
  context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -343,6 +349,18 @@ class Scheduler(
  target_worker=self.tp_worker,
  dp_rank=dp_rank,
  )
+ elif self.spec_algorithm.is_standalone():
+ from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+ self.draft_worker = StandaloneWorker(
+ gpu_id=gpu_id,
+ tp_rank=tp_rank,
+ moe_ep_rank=moe_ep_rank,
+ server_args=server_args,
+ nccl_port=port_args.nccl_port,
+ target_worker=self.tp_worker,
+ dp_rank=dp_rank,
+ )
  else:
  self.draft_worker = None
 
@@ -396,7 +414,7 @@ class Scheduler(
  f"max_prefill_tokens={self.max_prefill_tokens}, "
  f"max_running_requests={self.max_running_requests}, "
  f"context_len={self.model_config.context_len}, "
- f"available_gpu_mem={avail_mem:.2f} GB"
+ f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
  )
 
  # Init memory pool and cache
@@ -483,7 +501,7 @@ class Scheduler(
  enable=server_args.enable_memory_saver
  )
  self.offload_tags = set()
- self.init_profier()
+ self.init_profiler()
 
  self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
  self.input_blocker = (
@@ -495,6 +513,7 @@ class Scheduler(
  # Init metrics stats
  self.init_metrics(tp_rank, pp_rank, dp_rank)
  self.init_kv_events(server_args.kv_events_config)
+ self.init_dp_balance(dp_balance_meta)
 
  # Init disaggregation
  self.disaggregation_mode = DisaggregationMode(
@@ -510,7 +529,10 @@ class Scheduler(
  [
  (TokenizedGenerateReqInput, self.handle_generate_request),
  (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+ (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
+ (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
  (FlushCacheReqInput, self.flush_cache_wrapped),
+ (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
  (AbortReq, self.abort_request),
  (OpenSessionReqInput, self.open_session),
  (CloseSessionReqInput, self.close_session),
@@ -533,18 +555,10 @@ class Scheduler(
  (ExpertDistributionReq, self.expert_distribution_handle),
  (LoadLoRAAdapterReqInput, self.load_lora_adapter),
  (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
+ (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
  ]
  )
 
- self.balance_meta = dp_balance_meta
- if (
- server_args.enable_dp_attention
- and server_args.load_balance_method == "minimum_tokens"
- ):
- assert dp_balance_meta is not None
-
- self.recv_dp_balance_id_this_term = []
-
  def init_tokenizer(self):
  server_args = self.server_args
  self.is_generation = self.model_config.is_generation
@@ -621,8 +635,11 @@ class Scheduler(
  hicache_write_policy=server_args.hicache_write_policy,
  hicache_io_backend=server_args.hicache_io_backend,
  hicache_mem_layout=server_args.hicache_mem_layout,
+ enable_metrics=self.enable_metrics,
  hicache_storage_backend=server_args.hicache_storage_backend,
  hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
+ model_name=server_args.served_model_name,
+ storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
  )
  self.tp_worker.register_hicache_layer_transfer_counter(
  self.tree_cache.cache_controller.layer_done_counter
@@ -651,6 +668,21 @@ class Scheduler(
  page_size=self.page_size,
  disable=server_args.disable_radix_cache,
  )
+ elif server_args.enable_lmcache:
+ from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
+ LMCRadixCache,
+ )
+
+ self.tree_cache = LMCRadixCache(
+ req_to_token_pool=self.req_to_token_pool,
+ token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+ page_size=self.page_size,
+ disable=server_args.disable_radix_cache,
+ model_config=self.model_config,
+ tp_size=self.tp_size,
+ rank=self.tp_rank,
+ tp_group=self.tp_group,
+ )
  else:
  self.tree_cache = RadixCache(
  req_to_token_pool=self.req_to_token_pool,
@@ -1018,14 +1050,26 @@ class Scheduler(
  req
  for req in recv_reqs
  if isinstance(
- req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+ req,
+ (
+ TokenizedGenerateReqInput,
+ TokenizedEmbeddingReqInput,
+ BatchTokenizedGenerateReqInput,
+ BatchTokenizedEmbeddingReqInput,
+ ),
  )
  ]
  control_reqs = [
  req
  for req in recv_reqs
  if not isinstance(
- req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+ req,
+ (
+ TokenizedGenerateReqInput,
+ TokenizedEmbeddingReqInput,
+ BatchTokenizedGenerateReqInput,
+ BatchTokenizedEmbeddingReqInput,
+ ),
  )
  ]
  else:
@@ -1080,6 +1124,17 @@ class Scheduler(
  )
  self.send_to_tokenizer.send_pyobj(abort_req)
  continue
+
+ # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
+ if isinstance(recv_req, MultiTokenizerWrapper):
+ worker_id = recv_req.worker_id
+ recv_req = recv_req.obj
+ output = self._request_dispatcher(recv_req)
+ if output is not None:
+ output = MultiTokenizerWrapper(worker_id, output)
+ self.send_to_tokenizer.send_pyobj(output)
+ continue
+
  output = self._request_dispatcher(recv_req)
  if output is not None:
  if isinstance(output, RpcReqOutput):
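
The new `MultiTokenizerWrapper` branch above tags a control request with the id of the tokenizer worker that sent it, dispatches the unwrapped payload as usual, and re-wraps any reply so it can be routed back to the originating worker. A minimal sketch of that unwrap/re-wrap pattern with hypothetical types (the real classes live in sglang/srt/managers/io_struct.py):

    from dataclasses import dataclass
    from typing import Any, Optional


    @dataclass
    class Wrapper:
        worker_id: int
        obj: Any


    def dispatch(req: Any) -> Optional[str]:
        # Stand-in for Scheduler._request_dispatcher.
        return f"handled {req!r}"


    def handle(recv_req: Any) -> Optional[Any]:
        if isinstance(recv_req, Wrapper):
            worker_id = recv_req.worker_id
            output = dispatch(recv_req.obj)
            # Re-wrap so the reply goes back to the worker that asked.
            return Wrapper(worker_id, output) if output is not None else None
        return dispatch(recv_req)


    print(handle(Wrapper(worker_id=3, obj="flush_cache")))
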
@@ -1092,11 +1147,7 @@ class Scheduler(
  self,
  recv_req: TokenizedGenerateReqInput,
  ):
- if (
- self.server_args.enable_dp_attention
- and self.server_args.load_balance_method == "minimum_tokens"
- ):
- self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+ self.maybe_update_dp_balance_data(recv_req)
 
  # Create a new request
  if (
@@ -1253,6 +1304,17 @@ class Scheduler(
  else:
  self._add_request_to_queue(req)
 
+ def handle_batch_generate_request(
+ self,
+ recv_req: BatchTokenizedGenerateReqInput,
+ ):
+ """Handle optimized batch generate request."""
+ logger.debug(f"Processing batch generate request with {len(recv_req)} requests")
+
+ # Process each request in the batch
+ for tokenized_req in recv_req:
+ self.handle_generate_request(tokenized_req)
+
  def _add_request_to_queue(self, req: Req):
  req.queue_time_start = time.perf_counter()
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
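
`handle_batch_generate_request` above treats the batch input as a sized iterable (`len(recv_req)`, `for tokenized_req in recv_req`) and simply fans it out to the existing single-request handler. A hedged sketch of that shape; the container below is an illustrative stand-in, not the actual `BatchTokenizedGenerateReqInput`:

    from dataclasses import dataclass, field
    from typing import Iterator, List


    @dataclass
    class BatchReq:
        reqs: List[str] = field(default_factory=list)

        def __len__(self) -> int:
            return len(self.reqs)

        def __iter__(self) -> Iterator[str]:
            return iter(self.reqs)


    def handle_one(req: str) -> None:
        print(f"queued {req}")


    def handle_batch(batch: BatchReq) -> None:
        # Fan the batch out to the per-request path, as the diff does.
        print(f"Processing batch generate request with {len(batch)} requests")
        for req in batch:
            handle_one(req)


    handle_batch(BatchReq(["r1", "r2"]))
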
@@ -1269,10 +1331,11 @@ class Scheduler(
  def _prefetch_kvcache(self, req: Req):
  if self.enable_hicache_storage:
  req.init_next_round_input(self.tree_cache)
- last_hash = req.last_host_node.get_last_hash_value()
- matched_len = len(req.prefix_indices) + req.host_hit_length
- # todo, free-form fetching, calculating hash keys on the fly
- if (matched_len > 0 and last_hash is not None) or matched_len == 0:
+ if req.last_node.backuped:
+ # only to initiate the prefetch if the last node is backuped
+ # otherwise, the allocated GPU memory must be locked for integrity
+ last_hash = req.last_host_node.get_last_hash_value()
+ matched_len = len(req.prefix_indices) + req.host_hit_length
  new_input_tokens = req.fill_ids[matched_len:]
  self.tree_cache.prefetch_from_storage(
  req.rid, req.last_host_node, new_input_tokens, last_hash
@@ -1335,6 +1398,19 @@ class Scheduler(
  req.logprob_start_len = len(req.origin_input_ids) - 1
  self._add_request_to_queue(req)
 
+ def handle_batch_embedding_request(
+ self,
+ recv_req: BatchTokenizedEmbeddingReqInput,
+ ):
+ """Handle optimized batch embedding request."""
+ logger.debug(
+ f"Processing batch embedding request with {len(recv_req)} requests"
+ )
+
+ # Process each request in the batch
+ for tokenized_req in recv_req:
+ self.handle_embedding_request(tokenized_req)
+
  def self_check_during_idle(self):
  self.check_memory()
  self.check_tree_cache()
@@ -1362,9 +1438,11 @@ class Scheduler(
  _, _, available_size, evictable_size = self._get_token_info()
  protected_size = self.tree_cache.protected_size()
  memory_leak = (available_size + evictable_size) != (
+ # self.max_total_num_tokens
+ # if not self.enable_hierarchical_cache
+ # else self.max_total_num_tokens - protected_size
  self.max_total_num_tokens
- if not self.enable_hierarchical_cache
- else self.max_total_num_tokens - protected_size
+ - protected_size
  )
  token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
 
@@ -1460,9 +1538,14 @@ class Scheduler(
  # Move the chunked request out of the batch so that we can merge
  # only finished requests to running_batch.
  chunked_req_to_exclude.add(self.chunked_req)
- self.tree_cache.cache_unfinished_req(self.chunked_req)
+ self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
  # chunked request keeps its rid but will get a new req_pool_idx
- self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
+ if self.tp_worker.worker.model_runner.is_hybrid_gdn:
+ self.req_to_token_pool.free(
+ self.chunked_req.req_pool_idx, free_mamba_cache=False
+ )
+ else:
+ self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
  if self.last_batch and self.last_batch.forward_mode.is_extend():
  if self.last_batch.chunked_req is not None:
  # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
@@ -1509,11 +1592,7 @@ class Scheduler(
 
  # Handle DP attention
  if need_dp_attn_preparation:
- if (
- self.server_args.load_balance_method == "minimum_tokens"
- and self.forward_ct % 40 == 0
- ):
- self.handle_dp_balance_data(ret)
+ self.maybe_handle_dp_balance_data()
  ret = self.prepare_mlp_sync_batch(ret)
 
  return ret
@@ -1733,10 +1812,6 @@ class Scheduler(
  if self.spec_algorithm.is_none():
  model_worker_batch = batch.get_model_worker_batch()
 
- # update the consumer index of hicache to the running batch
- self.tp_worker.set_hicache_consumer(
- model_worker_batch.hicache_consumer_index
- )
  if self.pp_group.is_last_rank:
  logits_output, next_token_ids, can_run_cuda_graph = (
  self.tp_worker.forward_batch_generation(model_worker_batch)
@@ -1838,86 +1913,6 @@ class Scheduler(
  disable_overlap_schedule=self.server_args.disable_overlap_schedule,
  )
 
- def handle_dp_balance_data(self, local_batch: ScheduleBatch):
- def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
- """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
- recv_list = self.recv_dp_balance_id_this_term
- assert len(recv_list) <= 511, (
- "The number of requests received this round is too large. "
- "Please increase gather_tensor_size and onfly_info_size."
- )
- # The maximum size of the tensor used for gathering data from all workers.
- gather_tensor_size = 512
-
- # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
- recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
- recv_tensor[0] = holding_tokens_list
- recv_tensor[1] = len(
- recv_list
- ) # The first element is the length of the list.
- recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
- recv_list, dtype=torch.int32
- )
-
- if self.tp_rank == 0:
- gathered_list = [
- torch.zeros(gather_tensor_size, dtype=torch.int32)
- for _ in range(self.balance_meta.num_workers)
- ]
- else:
- gathered_list = None
-
- torch.distributed.gather(
- recv_tensor, gathered_list, group=self.tp_cpu_group
- )
-
- gathered_id_list_per_worker = None
- if self.tp_rank == 0:
- gathered_id_list_per_worker = []
- holding_tokens_list = []
- for tensor in gathered_list:
- holding_tokens_list.append(tensor[0].item())
- list_length = tensor[1].item()
- gathered_id_list_per_worker.append(
- tensor[2 : list_length + 2].tolist()
- )
-
- return gathered_id_list_per_worker, holding_tokens_list
-
- def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
- meta = self.balance_meta
-
- with meta.mutex:
- onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
- assert len(new_recv_rid_lists) == len(
- onfly_list
- ), "num_worker not equal"
- # 1.Check if the rid received by each worker this round is present in onfly.
- # If it is, remove the corresponding onfly item.
- worker_id = 0
- for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
- for new_recv_rid in new_recv_rids:
- assert (
- new_recv_rid in on_fly_reqs
- ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
- del on_fly_reqs[new_recv_rid]
- worker_id += 1
- # 2. Atomically write local_tokens and onfly into shm under the mutex
- meta.set_shared_onfly_info(onfly_list)
- meta.set_shared_local_tokens(local_tokens)
-
- holding_tokens = self.get_load()
-
- new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
- holding_tokens
- )
-
- self.recv_dp_balance_id_this_term.clear()
- if self.tp_rank == 0: # only first worker write info
- write_shared_dp_balance_info(
- new_recv_dp_balance_id_list, holding_token_list
- )
-
  @staticmethod
  def prepare_mlp_sync_batch_raw(
  local_batch: ScheduleBatch,
@@ -2164,6 +2159,16 @@ class Scheduler(
  success = self.flush_cache()
  return FlushCacheReqOutput(success=success)
 
+ def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput):
+ if self.enable_hierarchical_cache:
+ self.tree_cache.clear_storage_backend()
+ logger.info("Hierarchical cache cleared successfully!")
+ if_success = True
+ else:
+ logging.warning("Hierarchical cache is not enabled.")
+ if_success = False
+ return ClearHiCacheReqOutput(success=if_success)
+
  def flush_cache(self):
  """Flush the memory pool and cache."""
  if (
@@ -2248,10 +2253,9 @@ class Scheduler(
  "token_capacity": int(self.max_total_num_tokens),
  }
 
- if not _is_cpu:
- ret["memory_usage"]["cuda_graph"] = round(
- self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
- )
+ ret["memory_usage"]["graph"] = round(
+ self.tp_worker.worker.model_runner.graph_mem_usage, 2
+ )
 
  if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
  ret["avg_spec_accept_length"] = (
@@ -2334,7 +2338,14 @@ class Scheduler(
  # This only works for requests that have not started anything.
  # We still need to send something back to TokenizerManager to clean up the state.
  req = self.waiting_queue.pop(i)
+ if self.enable_hicache_storage:
+ # to release prefetch events associated with the request
+ self.tree_cache.release_aborted_request(req.rid)
  self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+ # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
+ if self.disaggregation_mode == DisaggregationMode.DECODE:
+ self.tree_cache.cache_finished_req(req)
+
  logger.debug(f"Abort queued request. {req.rid=}")
 
  # Delete the requests in the grammar queue
@@ -2414,6 +2425,10 @@ class Scheduler(
  result = self.tp_worker.unload_lora_adapter(recv_req)
  return result
 
+ def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
+ self.send_to_detokenizer.send_pyobj(recv_req)
+ return recv_req
+
  def slow_down(self, recv_req: SlowDownReqInput):
  t = recv_req.forward_sleep_time
  if t is not None and t <= 0:
@@ -2513,7 +2528,15 @@ def is_health_check_generate_req(recv_req):
 
 
  def is_work_request(recv_req):
- return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
+ return isinstance(
+ recv_req,
+ (
+ TokenizedGenerateReqInput,
+ TokenizedEmbeddingReqInput,
+ BatchTokenizedGenerateReqInput,
+ BatchTokenizedEmbeddingReqInput,
+ ),
+ )
 
 
  def run_scheduler_process(
@@ -2527,6 +2550,9 @@ def run_scheduler_process(
  pipe_writer,
  balance_meta: Optional[DPBalanceMeta] = None,
  ):
+ if (numa_node := server_args.numa_node) is not None:
+ numa_bind_to_node(numa_node[gpu_id])
+
  # Generate the prefix
  prefix = ""
  if dp_rank is not None: