sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cpu_graph_runner.py (new file)
@@ -0,0 +1,640 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Run the model with CPU torch.compile."""
+
+ # The implementation of CPUGraphRunner follows the CudaGraphRunner
+
+ from __future__ import annotations
+
+ import logging
+ from contextlib import contextmanager
+ from typing import TYPE_CHECKING, Callable, Optional, Union
+
+ import psutil
+ import torch
+ import tqdm
+
+ from sglang.srt.distributed import get_tensor_model_parallel_rank
+ from sglang.srt.distributed.parallel_state import GroupCoordinator
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.model_executor.forward_batch_info import (
+     CaptureHiddenMode,
+     ForwardBatch,
+     ForwardMode,
+     PPProxyTensors,
+ )
+ from sglang.srt.patch_torch import monkey_patch_torch_compile
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+ from sglang.srt.utils import (
+     log_info_on_rank0,
+     require_attn_tp_gather,
+     require_gathered_buffer,
+     require_mlp_sync,
+     require_mlp_tp_gather,
+ )
+
+ logger = logging.getLogger(__name__)
+
+ if TYPE_CHECKING:
+     from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+ @contextmanager
+ def patch_model(
+     model: torch.nn.Module,
+     enable_compile: bool,
+     num_tokens: int,
+     tp_group: GroupCoordinator,
+ ):
+     """Patch the model to make it compatible with torch.compile"""
+     backup_ca_comm = None
+
+     try:
+         if enable_compile:
+             backup_ca_comm = tp_group.ca_comm
+             # Use custom-allreduce here.
+             # We found the custom allreduce is much faster than the built-in allreduce in torch,
+             # even with ENABLE_INTRA_NODE_COMM=1.
+             # tp_group.ca_comm = None
+             yield torch.compile(
+                 torch.no_grad()(model.forward),
+                 dynamic=False,
+             )
+         else:
+             yield model.forward
+     finally:
+         if enable_compile:
+             tp_group.ca_comm = backup_ca_comm
+
+
+ def set_torch_compile_config():
+     import torch._dynamo.config
+     import torch._inductor.config
+
+     torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+     torch._inductor.config.freezing = True
+     torch._dynamo.config.accumulated_cache_size_limit = 1024
+     if hasattr(torch._dynamo.config, "cache_size_limit"):
+         torch._dynamo.config.cache_size_limit = 1024
+     monkey_patch_torch_compile()
+
+
+ def get_batch_sizes_to_capture(model_runner: ModelRunner):
+     server_args = model_runner.server_args
+     # cpu torch compile only speeds up decoding by
+     # reducing python overhead when bs is small
+     capture_bs = list(range(1, 17))
+     capture_bs = [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
+     capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+     capture_bs = list(sorted(set(capture_bs)))
+     assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
+     return capture_bs
+
+
+ def register_fake_ops():
+     """
+     Registers fake/meta implementations for all custom sgl_kernel CPU operators
+     using torch.library.register_fake to support torch.compile.
+     """
+
+     none_return_ops = [
+         "shm_allreduce",
+         "bmm_cpu",
+         "fused_add_rmsnorm_cpu",
+         "decode_attention_cpu",
+         "extend_attention_cpu",
+     ]
+     for op in none_return_ops:
+
+         @torch.library.register_fake(f"sgl_kernel::{op}")
+         def _(*args, **kwargs):
+             return
+
+     for op in [
+         "rmsnorm_cpu",
+         "l2norm_cpu",
+         "fused_experts_cpu",
+         "shared_expert_cpu",
+     ]:
+
+         @torch.library.register_fake(f"sgl_kernel::{op}")
+         def _(input, *args, **kwargs):
+             return torch.empty_like(input)
+
+     @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope")
+     def _(
+         hidden_states,
+         q_a_proj_weight,
+         q_b_proj_weight,
+         kv_a_proj_weight,
+         w_kc,
+         q_a_layernorm_weight,
+         kv_a_layernorm_weight,
+         positions,
+         cos_sin_cache,
+         eps,
+         use_int8_w8a8,
+         use_fp8_w8a16,
+         q_a_proj_scale,
+         q_b_proj_scale,
+         kv_a_proj_scale,
+         is_vnni,
+         block_size,
+     ):
+         num_seqs = hidden_states.shape[0]
+         num_heads = w_kc.shape[0]
+         kv_lora_rank = w_kc.shape[1]
+         qk_rope_head_dim = kv_a_proj_weight.shape[0] - kv_lora_rank
+         q_input = torch.empty(
+             num_seqs,
+             num_heads,
+             kv_lora_rank + qk_rope_head_dim,
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+         k_input = torch.empty(
+             num_seqs,
+             1,
+             kv_lora_rank + qk_rope_head_dim,
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+         v_input = k_input.narrow(-1, 0, kv_lora_rank)
+         return q_input, k_input, v_input
+
+     @torch.library.register_fake("sgl_kernel::rotary_embedding_cpu")
+     def _(positions, query, key, head_size, cos_sin_cache, is_neox):
+         if query.ndim == 2:
+             return query, key
+         else:
+             return torch.empty_like(query), torch.empty_like(key)
+
+     @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope_fused_weight")
+     def _(
+         hidden_states,
+         q_a_proj_weight,
+         q_b_proj_weight,
+         w_kc,
+         q_a_layernorm_weight,
+         kv_a_layernorm_weight,
+         positions,
+         cos_sin_cache,
+         eps,
+         use_int8_w8a8,
+         use_fp8_w8a16,
+         qkv_a_proj_scale,
+         q_b_proj_scale,
+         is_vnni,
+         block_size,
+         q_lora_rank,
+         kv_lora_rank,
+         qk_rope_head_dim,
+     ):
+         num_seqs = hidden_states.shape[0]
+         num_heads = w_kc.shape[0]
+         kv_lora_rank = w_kc.shape[1]
+         weight_chunks = torch.split(
+             q_a_proj_weight, [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=0
+         )
+         qk_rope_head_dim = weight_chunks[1].shape[0] - kv_lora_rank
+         q_input = torch.empty(
+             num_seqs,
+             num_heads,
+             kv_lora_rank + qk_rope_head_dim,
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+         k_input = torch.empty(
+             num_seqs,
+             1,
+             kv_lora_rank + qk_rope_head_dim,
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+         v_input = k_input.narrow(-1, 0, kv_lora_rank)
+         return q_input, k_input, v_input
+
+     @torch.library.register_fake("sgl_kernel::weight_packed_linear")
+     def _(x, weight, bias, is_vnni):
+         return x.new_empty(x.shape[0], weight.shape[0])
+
+     @torch.library.register_fake("sgl_kernel::per_token_quant_int8_cpu")
+     def _(input):
+         M = input.shape[0]
+         K = input.shape[1]
+         Aq = input.new_empty(M, K, dtype=torch.int8)
+         As = input.new_empty(M, dtype=torch.float32)
+         return Aq, As
+
+     @torch.library.register_fake("sgl_kernel::int8_scaled_mm_cpu")
+     def _(mat1, mat2, scales1, scales2, bias, out_dtype, is_vnni):
+         M = mat1.shape[0]
+         N = mat2.shape[0]
+         out = mat1.new_empty(M, N, dtype=out_dtype)
+         return out
+
+     @torch.library.register_fake("sgl_kernel::grouped_topk_cpu")
+     def _(
+         hidden_states,
+         gating_output,
+         topk,
+         renormalize,
+         num_expert_group,
+         topk_group,
+         num_fused_shared_experts,
+         routed_scaling_factor,
+         num_token_non_padded,
+     ):
+         num_tokens = hidden_states.shape[0]
+         shape = (num_tokens, topk)
+         device = hidden_states.device
+         topk_weights = torch.empty(shape, device=device, dtype=torch.float32)
+         topk_ids = torch.empty(shape, device=device, dtype=torch.int)
+         return topk_weights, topk_ids
+
+     @torch.library.register_fake("sgl_kernel::biased_grouped_topk_cpu")
+     def _(
+         hidden_states,
+         gating_output,
+         correction_bias,
+         topk,
+         renormalize,
+         num_expert_group,
+         topk_group,
+         num_fused_shared_experts,
+         routed_scaling_factor,
+         num_token_non_padded,
+     ):
+         num_tokens = hidden_states.shape[0]
+         shape = (num_tokens, topk)
+         device = hidden_states.device
+         topk_weights = torch.empty(shape, device=device, dtype=torch.float32)
+         topk_ids = torch.empty(shape, device=device, dtype=torch.int)
+         return topk_weights, topk_ids
+
+     @torch.library.register_fake("sgl_kernel::topk_sigmoid_cpu")
+     def _(hidden_states, gating_output, topk, renormalize):
+         num_tokens = hidden_states.shape[0]
+         shape = (num_tokens, topk)
+         return (
+             torch.empty(shape, device=hidden_states.device, dtype=torch.float),
+             torch.empty(shape, device=hidden_states.device, dtype=torch.int),
+         )
+
+     @torch.library.register_fake("sgl_kernel::topk_softmax_cpu")
+     def _(
+         hidden_states,
+         gating_output,
+         topk,
+         renormalize,
+     ):
+         num_tokens = hidden_states.shape[0]
+         shape = (num_tokens, topk)
+         return (
+             torch.empty(shape, device=hidden_states.device, dtype=torch.float),
+             torch.empty(shape, device=hidden_states.device, dtype=torch.int),
+         )
+
+     @torch.library.register_fake("sgl_kernel::silu_and_mul_cpu")
+     def _(input):
+         return input.new_empty(input.shape[0], input.shape[1] // 2)
+
+     @torch.library.register_fake("sgl_kernel::int8_scaled_mm_with_quant")
+     def _(
+         mat1,
+         mat2,
+         scales2,
+         bias,
+         out_dtype,
+         is_vnni,
+     ):
+         M = mat1.shape[0]
+         N = mat2.shape[0]
+         return mat1.new_empty(M, N, dtype=out_dtype)
+
+     @torch.library.register_fake("sgl_kernel::fp8_scaled_mm_cpu")
+     def _(
+         mat1,
+         mat2,
+         scales2,
+         block_size,
+         bias,
+         out_dtype,
+         is_vnni,
+     ):
+         M = mat1.shape[0]
+         N = mat2.shape[0]
+         return mat1.new_empty(M, N, dtype=out_dtype)
+
+
+ # TODO Remove unnecessary settings for CPUGraphRunner.
+ # Re-abstract the graph runner and restructure CPUGraphRunner to reuse the same logic.
+ class CPUGraphRunner:
+     """A CPUGraphRunner runs the forward pass of a model with cpu torch.compile."""
+
+     def __init__(self, model_runner: ModelRunner):
+         # Parse args
+         self.model_runner = model_runner
+         self.device = model_runner.device
+         self.graphs = {}
+         self.output_buffers = {}
+         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
+         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+         self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
+         self.enable_two_batch_overlap = (
+             model_runner.server_args.enable_two_batch_overlap
+         )
+         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
+         self.enable_profile_cuda_graph = (
+             model_runner.server_args.enable_profile_cuda_graph
+         )
+         self.tp_size = model_runner.server_args.tp_size
+         self.dp_size = model_runner.server_args.dp_size
+         self.pp_size = model_runner.server_args.pp_size
+
+         self.capture_forward_mode = ForwardMode.DECODE
+         self.capture_hidden_mode = CaptureHiddenMode.NULL
+         self.num_tokens_per_bs = 1
+
+         # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
+         if model_runner.server_args.enable_return_hidden_states:
+             self.capture_hidden_mode = CaptureHiddenMode.FULL
+
+         assert (
+             not self.model_runner.server_args.enable_lora
+         ), "CPUGraphRunner does not support LoRA yet."
+         assert (
+             not self.enable_two_batch_overlap
+         ), "CPUGraphRunner does not support two batch overlap yet."
+         assert (
+             not self.require_mlp_tp_gather
+         ), "CPUGraphRunner does not support MLP TP gather yet."
+         assert (
+             not self.require_mlp_sync
+         ), "CPUGraphRunner does not support MLP sync yet."
+         assert (
+             not self.require_gathered_buffer
+         ), "CPUGraphRunner does not support gathered buffer yet."
+         assert (
+             model_runner.spec_algorithm == SpeculativeAlgorithm.NONE
+         ), "CPUGraphRunner does not support speculative inference yet."
+         # TODO add compile support for encoder-decoder models
+         assert (
+             not self.is_encoder_decoder
+         ), "CPUGraphRunner does not support encoder-decoder models yet."
+         assert self.dp_size == 1, "CPUGraphRunner does not support DP yet."
+         assert self.pp_size == 1, "CPUGraphRunner does not support PP yet."
+
+         # Batch sizes to capture
+         self.capture_bs = get_batch_sizes_to_capture(model_runner)
+         log_info_on_rank0(logger, f"Capture cpu graph bs {self.capture_bs}")
+         # Attention backend
+         self.max_bs = max(self.capture_bs)
+         self.max_num_token = self.max_bs * self.num_tokens_per_bs
+
+         self.seq_len_fill_value = (
+             self.model_runner.attn_backend.get_graph_seq_len_fill_value()
+         )
+
+         if self.enable_torch_compile:
+             register_fake_ops()
+             set_torch_compile_config()
+
+         # Graph inputs
+         with torch.device(self.device):
+             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
+             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int64)
+             self.seq_lens = torch.full(
+                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int64
+             )
+             self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
+             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
+             self.num_token_non_padded = torch.zeros((1,), dtype=torch.int64)
+             self.custom_mask = torch.ones(
+                 (
+                     (self.seq_lens.sum().item() + self.max_num_token)
+                     * self.num_tokens_per_bs
+                 ),
+                 dtype=torch.bool,
+                 device=self.device,
+             )
+
+         # Capture
+         try:
+             self.capture()
+         except RuntimeError as e:
+             raise Exception(
+                 f"Capture CPU graph failed: {e}\n{CPU_GRAPH_CAPTURE_FAILED_MSG}"
+             )
+
+     def can_run(self, forward_batch: ForwardBatch):
+         is_bs_supported = forward_batch.batch_size in self.graphs
+
+         requested_capture_hidden_mode = max(
+             forward_batch.capture_hidden_mode,
+             (
+                 forward_batch.spec_info.capture_hidden_mode
+                 if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
+                 is not None
+                 else CaptureHiddenMode.NULL
+             ),
+         )
+         capture_hidden_mode_matches = (
+             requested_capture_hidden_mode == CaptureHiddenMode.NULL
+             or requested_capture_hidden_mode == self.capture_hidden_mode
+         )
+
+         return is_bs_supported and capture_hidden_mode_matches
+
+     def capture(self) -> None:
+         capture_range = (
+             tqdm.tqdm(list(reversed(self.capture_bs)))
+             if get_tensor_model_parallel_rank() == 0
+             else reversed(self.capture_bs)
+         )
+         for bs in capture_range:
+             if get_tensor_model_parallel_rank() == 0:
+                 avail_mem = psutil.virtual_memory().available / (1 << 30)
+                 capture_range.set_description(
+                     f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
+                 )
+
+             with patch_model(
+                 self.model_runner.model,
+                 bs in self.capture_bs,
+                 num_tokens=bs * self.num_tokens_per_bs,
+                 tp_group=self.model_runner.tp_group,
+             ) as forward:
+                 (
+                     graph,
+                     output_buffers,
+                 ) = self.capture_one_batch_size(bs, forward)
+                 self.graphs[bs] = graph
+                 self.output_buffers[bs] = output_buffers
+
+     def capture_one_batch_size(self, bs: int, forward: Callable):
+         num_tokens = bs * self.num_tokens_per_bs
+
+         # Graph inputs
+         input_ids = self.input_ids[:num_tokens]
+         req_pool_indices = self.req_pool_indices[:bs]
+         seq_lens = self.seq_lens[:bs]
+         out_cache_loc = self.out_cache_loc[:num_tokens]
+         positions = self.positions[:num_tokens]
+         mrope_positions = self.mrope_positions[:, :bs]
+         self.num_token_non_padded[...] = num_tokens
+
+         spec_info = self.get_spec_info(num_tokens)
+         if self.capture_hidden_mode != CaptureHiddenMode.FULL:
+             self.capture_hidden_mode = (
+                 spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+             )
+
+         forward_batch = ForwardBatch(
+             forward_mode=self.capture_forward_mode,
+             batch_size=bs,
+             input_ids=input_ids,
+             req_pool_indices=req_pool_indices,
+             seq_lens=seq_lens,
+             req_to_token_pool=self.model_runner.req_to_token_pool,
+             token_to_kv_pool=self.model_runner.token_to_kv_pool,
+             attn_backend=self.model_runner.attn_backend,
+             out_cache_loc=out_cache_loc,
+             seq_lens_sum=seq_lens.sum().item(),
+             return_logprob=False,
+             positions=positions,
+             mrope_positions=mrope_positions,
+             spec_algorithm=self.model_runner.spec_algorithm,
+             spec_info=spec_info,
+             capture_hidden_mode=self.capture_hidden_mode,
+             num_token_non_padded=self.num_token_non_padded,
+             global_forward_mode=self.capture_forward_mode,
+         )
+
+         # Attention backend
+         self.model_runner.attn_backend.init_forward_metadata(forward_batch)
+         # Do inference to avoid setting attr at runtime, e.g.,
+         # self.attn_mha.kv_b_proj = self.kv_b_proj for full graph compile on CPU
+         self.model_runner.model.forward(
+             forward_batch.input_ids,
+             forward_batch.positions,
+             forward_batch,
+         )
+
+         # Run and capture
+         def run_once():
+             # Clean intermediate result cache for DP attention
+             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+             logits_output_or_pp_proxy_tensors = forward(
+                 input_ids,
+                 forward_batch.positions,
+                 forward_batch,
+             )
+             return logits_output_or_pp_proxy_tensors
+
+         with torch.no_grad():
+             for _ in range(2):
+                 self.model_runner.tp_group.barrier()
+                 out = run_once()
+         return forward, out
+
+     def recapture_if_needed(self, forward_batch: ForwardBatch):
+
+         # If the required capture_hidden_mode changes, we need to recapture the graph
+
+         # These are the different factors that can influence the capture_hidden_mode
+         capture_hidden_mode_required_by_forward_batch = (
+             forward_batch.capture_hidden_mode
+         )
+         capture_hidden_mode_required_by_spec_info = getattr(
+             forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
+         )
+         capture_hidden_mode_required_for_returning_hidden_states = (
+             CaptureHiddenMode.FULL
+             if self.model_runner.server_args.enable_return_hidden_states
+             else CaptureHiddenMode.NULL
+         )
+
+         # Determine the highest capture_hidden_mode required
+         # (If we have FULL, we can emulate LAST or NULL)
+         # (If we have LAST, we can emulate NULL)
+         required_capture_hidden_mode = max(
+             capture_hidden_mode_required_by_forward_batch,
+             capture_hidden_mode_required_by_spec_info,
+             capture_hidden_mode_required_for_returning_hidden_states,
+         )
+
+         # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
+         if self.capture_hidden_mode != required_capture_hidden_mode:
+             self.capture_hidden_mode = required_capture_hidden_mode
+             self.capture()
+
+     # TODO add padding support for CPUGraphRunner
+     def replay(
+         self,
+         forward_batch: ForwardBatch,
+         skip_attn_backend_init: bool = False,
+         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+     ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+         assert (
+             pp_proxy_tensors is None
+         ), "PPProxyTensors is not supported in CPUGraphRunner yet."
+         self.recapture_if_needed(forward_batch)
+         self.model_runner.attn_backend.init_forward_metadata(forward_batch)
+         output = self.graphs[forward_batch.batch_size](
+             forward_batch.input_ids,
+             forward_batch.positions,
+             forward_batch,
+         )
+         return output
+
+     def get_spec_info(self, num_tokens: int):
+         spec_info = None
+         if self.model_runner.spec_algorithm.is_eagle():
+             from sglang.srt.speculative.eagle_utils import EagleVerifyInput
+
+             if self.model_runner.is_draft_worker:
+                 raise RuntimeError("This should not happen.")
+             else:
+                 spec_info = EagleVerifyInput(
+                     draft_token=None,
+                     custom_mask=self.custom_mask,
+                     positions=None,
+                     retrive_index=None,
+                     retrive_next_token=None,
+                     retrive_next_sibling=None,
+                     retrive_cum_len=None,
+                     spec_steps=self.model_runner.server_args.speculative_num_steps,
+                     topk=self.model_runner.server_args.speculative_eagle_topk,
+                     draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+                     capture_hidden_mode=CaptureHiddenMode.FULL,
+                     seq_lens_sum=None,
+                     seq_lens_cpu=None,
+                 )
+
+         return spec_info
+
+
+ CPU_GRAPH_CAPTURE_FAILED_MSG = (
+     "Possible solutions:\n"
+     "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
+     "2. set --torch-compile-max-bs to a smaller value (e.g., 8)\n"
+     "3. disable torch compile by not using --enable-torch-compile\n"
+     "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
+ )
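The heart of this new file is the fake-op registry: torch.compile can only trace custom C++ kernels if each one has a fake (meta) implementation that reports output shapes and dtypes without computing values. A minimal, self-contained sketch of the same pattern, assuming PyTorch >= 2.4 (demo::scale is a hypothetical op, not part of sgl_kernel):

import torch

# Hypothetical custom op standing in for a compiled C++ kernel,
# like the sgl_kernel CPU ops registered above.
@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# The fake implementation only describes the output tensor's shape and
# dtype; that is all the compiler needs to trace through the op.
@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)

# Mirrors patch_model(): compile a no_grad-wrapped callable with static shapes.
compiled = torch.compile(torch.no_grad()(lambda x: scale(x, 2.0) + 1), dynamic=False)
print(compiled(torch.ones(4)))  # tensor([3., 3., 3., 3.])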
sglang/srt/model_executor/cuda_graph_runner.py
@@ -271,7 +271,10 @@ class CudaGraphRunner:
          self.capture_forward_mode = ForwardMode.DECODE
          self.capture_hidden_mode = CaptureHiddenMode.NULL
          self.num_tokens_per_bs = 1
-         if model_runner.spec_algorithm.is_eagle():
+         if (
+             model_runner.spec_algorithm.is_eagle()
+             or model_runner.spec_algorithm.is_standalone()
+         ):
              if self.model_runner.is_draft_worker:
                  raise RuntimeError("This should not happen")
              else:
@@ -317,7 +320,9 @@ class CudaGraphRunner:
              (self.max_num_token,), dtype=self._cache_loc_dtype()
          )
          self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
+         self.mrope_positions = torch.zeros(
+             (3, self.max_num_token), dtype=torch.int64
+         )
          self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
          self.tbo_plugin = TboCudaGraphRunnerPlugin()
@@ -532,7 +537,7 @@ class CudaGraphRunner:
              encoder_lens = self.encoder_lens[:bs]
          else:
              encoder_lens = None
-         mrope_positions = self.mrope_positions[:, :bs]
+         mrope_positions = self.mrope_positions[:, :num_tokens]
          next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
          self.num_token_non_padded[...] = num_tokens
@@ -751,7 +756,7 @@ class CudaGraphRunner:
          if self.is_encoder_decoder:
              self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
          if forward_batch.mrope_positions is not None:
-             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+             self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
          if self.require_gathered_buffer:
              self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
              self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
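The mrope_positions changes in the three hunks above are one fix: M-RoPE carries three position components per token (e.g., temporal/height/width in Qwen2-VL-style models), so when num_tokens_per_bs > 1 (speculative decoding) the capture buffer and every slice or copy into it must be sized by token count, not batch size. A toy illustration of the invariant (shapes only, with made-up sizes):

import torch

max_bs, num_tokens_per_bs = 8, 4  # made-up capture limits
max_num_token = max_bs * num_tokens_per_bs
# The buffer is now allocated per token: 3 M-RoPE components per token.
mrope_positions = torch.zeros((3, max_num_token), dtype=torch.int64)

bs = 2
num_tokens = bs * num_tokens_per_bs  # exceeds bs whenever num_tokens_per_bs > 1
incoming = torch.arange(3 * num_tokens, dtype=torch.int64).reshape(3, num_tokens)

# Slicing by num_tokens covers every draft token; the old [:, :bs] slice
# would have truncated the copy as soon as num_tokens_per_bs > 1.
mrope_positions[:, :num_tokens].copy_(incoming)
assert mrope_positions[:, :num_tokens].shape == (3, num_tokens)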
@@ -825,7 +830,10 @@ class CudaGraphRunner:

      def get_spec_info(self, num_tokens: int):
          spec_info = None
-         if self.model_runner.spec_algorithm.is_eagle():
+         if (
+             self.model_runner.spec_algorithm.is_eagle()
+             or self.model_runner.spec_algorithm.is_standalone()
+         ):
              from sglang.srt.speculative.eagle_utils import EagleVerifyInput

              if self.model_runner.is_draft_worker:
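This hunk and the first one make the same change: the EagleVerifyInput capture path is now shared with the new STANDALONE speculative algorithm (see sglang/srt/speculative/standalone_worker.py in the file list). A schematic sketch of the dispatch, assuming a simplified enum (the real SpeculativeAlgorithm has more members, and pick_spec_info is an illustrative helper, not sglang API):

from enum import Enum, auto

class SpeculativeAlgorithm(Enum):
    NONE = auto()
    EAGLE = auto()
    STANDALONE = auto()

    def is_eagle(self) -> bool:
        return self is SpeculativeAlgorithm.EAGLE

    def is_standalone(self) -> bool:
        return self is SpeculativeAlgorithm.STANDALONE

def pick_spec_info(alg: SpeculativeAlgorithm) -> str:
    # 0.5.1.post2 checked only is_eagle(); 0.5.2 lets STANDALONE reuse
    # the same EAGLE-style verify input during graph capture.
    if alg.is_eagle() or alg.is_standalone():
        return "EagleVerifyInput"
    return "no spec_info"

print(pick_spec_info(SpeculativeAlgorithm.STANDALONE))  # EagleVerifyInput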