sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -1,5 +1,7 @@
  from __future__ import annotations

+ import enum
+
  # Copyright 2023-2024 SGLang Team
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -35,10 +37,11 @@ import copy
  import dataclasses
  import logging
  import threading
+ import time
  from enum import Enum, auto
  from http import HTTPStatus
  from itertools import chain
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
+ from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

  import numpy as np
  import torch
@@ -51,6 +54,7 @@ from sglang.srt.disaggregation.base import BaseKVSender
  from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
      ScheduleBatchDisaggregationDecodeMixin,
  )
+ from sglang.srt.disaggregation.utils import DisaggregationMode
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
  from sglang.srt.mem_cache.allocator import (
      BaseTokenToKVPoolAllocator,
@@ -58,10 +62,10 @@ from sglang.srt.mem_cache.allocator import (
  )
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
- from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
  from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
+ from sglang.srt.mem_cache.radix_cache import RadixKey
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
- from sglang.srt.metrics.collector import TimeStats
+ from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
  from sglang.srt.sampling.sampling_params import SamplingParams
@@ -70,8 +74,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton

  if TYPE_CHECKING:
      from sglang.srt.configs.model_config import ModelConfig
-     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+     from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm

  INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -86,6 +89,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
      "disable_flashinfer_cutlass_moe_fp4_allgather",
      "disable_radix_cache",
      "enable_dp_lm_head",
+     "enable_fp32_lm_head",
      "flashinfer_mxfp4_moe_precision",
      "enable_flashinfer_allreduce_fusion",
      "moe_dense_tp_size",
@@ -107,6 +111,9 @@ GLOBAL_SERVER_ARGS_KEYS = [
      "enable_symm_mem",
      "enable_custom_logit_processor",
      "disaggregation_mode",
+     "enable_deterministic_inference",
+     "nsa_prefill",
+     "nsa_decode",
  ]

  # Put some global args for easy access
@@ -407,6 +414,23 @@ class MultimodalInputs:
          # other args would be kept intact


+ class RequestStage(str, enum.Enum):
+     # prefill
+     PREFILL_WAITING = "prefill_waiting"
+
+     # disaggregation prefill
+     PREFILL_PREPARE = "prefill_prepare"
+     PREFILL_BOOTSTRAP = "prefill_bootstrap"
+     PREFILL_FORWARD = "prefill_forward"
+     PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
+
+     # disaggregation decode
+     DECODE_PREPARE = "decode_prepare"
+     DECODE_BOOTSTRAP = "decode_bootstrap"
+     DECODE_WAITING = "decode_waiting"
+     DECODE_TRANSFERRED = "decode_transferred"
+
+
  class Req:
      """The input and output status of a request."""

@@ -431,8 +455,12 @@ class Req:
          bootstrap_host: Optional[str] = None,
          bootstrap_port: Optional[int] = None,
          bootstrap_room: Optional[int] = None,
+         disagg_mode: Optional[DisaggregationMode] = None,
          data_parallel_rank: Optional[int] = None,
          vocab_size: Optional[int] = None,
+         priority: Optional[int] = None,
+         metrics_collector: Optional[SchedulerMetricsCollector] = None,
+         extra_key: Optional[str] = None,
      ):
          # Input and output info
          self.rid = rid
@@ -465,6 +493,14 @@ class Req:
          self.sampling_params = sampling_params
          self.custom_logit_processor = custom_logit_processor
          self.return_hidden_states = return_hidden_states
+
+         # extra key for classifying the request (e.g. cache_salt)
+         if lora_id is not None:
+             extra_key = (
+                 extra_key or ""
+             ) + lora_id  # lora_id is concatenated to the extra key
+
+         self.extra_key = extra_key
          self.lora_id = lora_id

          # Memory pool info
@@ -483,6 +519,7 @@ class Req:
          self.stream = stream
          self.eos_token_ids = eos_token_ids
          self.vocab_size = vocab_size
+         self.priority = priority

          # For incremental decoding
          # ----- | --------- read_ids -------|
@@ -512,6 +549,8 @@ class Req:
          self.host_hit_length = 0
          # The node to lock until for swa radix tree lock ref
          self.swa_uuid_for_lock: Optional[int] = None
+         # The prefix length of the last prefix matching
+         self.last_matched_prefix_len: int = 0

          # Whether or not if it is chunked. It increments whenever
          # it is chunked, and decrement whenever chunked request is
@@ -573,6 +612,8 @@ class Req:
          ) = None
          self.hidden_states: List[List[float]] = []
          self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
+         self.output_topk_p = None
+         self.output_topk_index = None

          # Embedding (return values)
          self.embedding = None
@@ -590,10 +631,10 @@ class Req:
          self.spec_verify_ct = 0

          # For metrics
-         self.time_stats: TimeStats = TimeStats()
+         self.metrics_collector = metrics_collector
+         self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
          self.has_log_time_stats: bool = False
-         self.queue_time_start = None
-         self.queue_time_end = None
+         self.last_tic = time.monotonic()

          # For disaggregation
          self.bootstrap_host: str = bootstrap_host
@@ -624,7 +665,21 @@ class Req:
      @property
      def is_prefill_only(self) -> bool:
          """Check if this request is prefill-only (no token generation needed)."""
-         return self.sampling_params.max_new_tokens == 0
+         # NOTE: when spec is enabled, prefill_only optimizations are disabled
+         return (
+             self.sampling_params.max_new_tokens == 0
+             and global_server_args_dict["speculative_algorithm"] is None
+         )
+
+     def add_latency(self, stage: RequestStage):
+         if self.metrics_collector is None:
+             return
+
+         now = time.monotonic()
+         self.metrics_collector.observe_per_stage_req_latency(
+             stage.value, now - self.last_tic
+         )
+         self.last_tic = now

      def extend_image_inputs(self, image_inputs):
          if self.multimodal_inputs is None:
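
For context, the hunk above replaces the old queue_time_start/queue_time_end bookkeeping with per-stage timing: each Req keeps a monotonic last_tic timestamp and, when it crosses a stage boundary, reports the elapsed time under that stage's label. The sketch below illustrates the pattern under the assumption that the collector exposes observe_per_stage_req_latency as shown in the diff; the Stub* classes are illustrative stand-ins, not sglang APIs.

import enum
import time


class RequestStage(str, enum.Enum):
    # Subset of the stages introduced in the diff above.
    PREFILL_WAITING = "prefill_waiting"
    PREFILL_FORWARD = "prefill_forward"


class StubMetricsCollector:
    # Stand-in for the collector's observe_per_stage_req_latency hook.
    def observe_per_stage_req_latency(self, stage: str, latency_s: float) -> None:
        print(f"{stage}: {latency_s * 1e3:.2f} ms")


class StubReq:
    def __init__(self, metrics_collector=None):
        self.metrics_collector = metrics_collector
        self.last_tic = time.monotonic()

    def add_latency(self, stage: RequestStage) -> None:
        # Same shape as Req.add_latency above: record the time spent since the
        # previous stage boundary, then reset the timer for the next stage.
        if self.metrics_collector is None:
            return
        now = time.monotonic()
        self.metrics_collector.observe_per_stage_req_latency(
            stage.value, now - self.last_tic
        )
        self.last_tic = now


req = StubReq(metrics_collector=StubMetricsCollector())
time.sleep(0.01)  # pretend the request sat in the waiting queue
req.add_latency(RequestStage.PREFILL_WAITING)
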
@@ -642,26 +697,17 @@ class Req:
      ):
          self.fill_ids = self.origin_input_ids + self.output_ids
          if tree_cache is not None:
-             if isinstance(tree_cache, LoRARadixCache):
-                 (
-                     self.prefix_indices,
-                     self.last_node,
-                     self.last_host_node,
-                     self.host_hit_length,
-                 ) = tree_cache.match_prefix_with_lora_id(
-                     key=LoRAKey(
-                         lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
-                     ),
-                 )
-             else:
-                 (
-                     self.prefix_indices,
-                     self.last_node,
-                     self.last_host_node,
-                     self.host_hit_length,
-                 ) = tree_cache.match_prefix(
-                     key=self.adjust_max_prefix_ids(),
-                 )
+             (
+                 self.prefix_indices,
+                 self.last_node,
+                 self.last_host_node,
+                 self.host_hit_length,
+             ) = tree_cache.match_prefix(
+                 key=RadixKey(
+                     token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
+                 ),
+             )
+             self.last_matched_prefix_len = len(self.prefix_indices)
          self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

      def adjust_max_prefix_ids(self):
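
The hunk above removes the LoRARadixCache special case: the lora_id (and any cache_salt-style extra_key set in Req.__init__ earlier in this diff) now travels inside the RadixKey itself, so entries with different extra keys can never match each other's prefixes. A toy sketch of that idea follows; ToyRadixKey and ToyPrefixCache are made up for illustration (the real RadixKey and radix-tree matching live in sglang/srt/mem_cache/radix_cache.py and are not shown in this diff).

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass(frozen=True)
class ToyRadixKey:
    # Mirrors the shape used above: token ids plus an optional extra key
    # (e.g. a lora_id or cache salt) that partitions the cache namespace.
    token_ids: Tuple[int, ...]
    extra_key: Optional[str] = None


class ToyPrefixCache:
    def __init__(self) -> None:
        self._store: Dict[Tuple[Optional[str], Tuple[int, ...]], List[int]] = {}

    def insert(self, key: ToyRadixKey, kv_indices: List[int]) -> None:
        self._store[(key.extra_key, key.token_ids)] = kv_indices

    def match_prefix(self, key: ToyRadixKey) -> List[int]:
        # Only entries with the same extra_key are eligible, so a request using
        # LoRA adapter "A" never reuses KV cache written without that adapter.
        best: List[int] = []
        for (extra, tokens), kv in self._store.items():
            if extra != key.extra_key:
                continue
            matched = 0
            for a, b in zip(tokens, key.token_ids):
                if a != b:
                    break
                matched += 1
            if matched > len(best):
                best = kv[:matched]
        return best


cache = ToyPrefixCache()
cache.insert(ToyRadixKey((1, 2, 3, 4), extra_key="loraA"), [10, 11, 12, 13])
print(cache.match_prefix(ToyRadixKey((1, 2, 3, 9), extra_key="loraA")))  # [10, 11, 12]
print(cache.match_prefix(ToyRadixKey((1, 2, 3, 9), extra_key=None)))  # []
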
@@ -794,10 +840,10 @@ class Req:
              return

          if self.bootstrap_room is not None:
-             prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
+             prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
          else:
-             prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.get_type().value})"
-         logger.info(f"{prefix}: {self.time_stats}")
+             prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
+         logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
          self.has_log_time_stats = True

      def set_finish_with_abort(self, error_msg: str):
@@ -820,10 +866,6 @@ class Req:
          )


- # Batch id
- bid = 0
-
-
  @dataclasses.dataclass
  class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      """Store all information of a batch on the scheduler."""
@@ -860,6 +902,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      token_type_ids: torch.Tensor = None  # shape: [b], int64
      req_pool_indices: torch.Tensor = None  # shape: [b], int64
      seq_lens: torch.Tensor = None  # shape: [b], int64
+     seq_lens_cpu: torch.Tensor = None  # shape: [b], int64
      # The output locations of the KV cache
      out_cache_loc: torch.Tensor = None  # shape: [b], int64
      output_ids: torch.Tensor = None  # shape: [b], int64
@@ -915,7 +958,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

      # Speculative decoding
      spec_algorithm: SpeculativeAlgorithm = None
-     spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
+     # spec_info: Optional[SpecInput] = None
+     spec_info: Optional[SpecInput] = None

      # Whether to return hidden states
      return_hidden_states: bool = False
@@ -1014,7 +1058,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      def alloc_paged_token_slots_extend(
          self,
          prefix_lens: torch.Tensor,
+         prefix_lens_cpu: torch.Tensor,
          seq_lens: torch.Tensor,
+         seq_lens_cpu: torch.Tensor,
          last_loc: torch.Tensor,
          extend_num_tokens: int,
          backup_state: bool = False,
@@ -1022,7 +1068,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          # Over estimate the number of tokens: assume each request needs a new page.
          num_tokens = (
              extend_num_tokens
-             + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+             + len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
          )
          self._evict_tree_cache_if_needed(num_tokens)

@@ -1030,7 +1076,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              state = self.token_to_kv_pool_allocator.backup_state()

          out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
-             prefix_lens, seq_lens, last_loc, extend_num_tokens
+             prefix_lens,
+             prefix_lens_cpu,
+             seq_lens,
+             seq_lens_cpu,
+             last_loc,
+             extend_num_tokens,
          )
          if out_cache_loc is None:
              error_msg = (
@@ -1049,6 +1100,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
      def alloc_paged_token_slots_decode(
          self,
          seq_lens: torch.Tensor,
+         seq_lens_cpu: torch.Tensor,
          last_loc: torch.Tensor,
          backup_state: bool = False,
      ):
@@ -1059,7 +1111,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          if backup_state:
              state = self.token_to_kv_pool_allocator.backup_state()

-         out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
+         out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
+             seq_lens, seq_lens_cpu, last_loc
+         )
          if out_cache_loc is None:
              error_msg = (
                  f"Decode out of memory. Try to lower your batch size.\n"
@@ -1128,6 +1182,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
              self.device, non_blocking=True
          )
+         self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)

          if not decoder_out_cache_loc:
              self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
@@ -1176,12 +1231,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
              self.device, non_blocking=True
          )
+         seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
          orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
              self.device, non_blocking=True
          )
          prefix_lens_tensor = torch.tensor(
              prefix_lens, dtype=torch.int64, device=self.device
          )
+         prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)

          token_type_ids_tensor = None
          if len(token_type_ids) > 0:
@@ -1308,13 +1365,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                  prefix_lens_tensor,
              )
              out_cache_loc = self.alloc_paged_token_slots_extend(
-                 prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
+                 prefix_lens_tensor,
+                 prefix_lens_cpu_tensor,
+                 seq_lens_tensor,
+                 seq_lens_cpu,
+                 last_loc,
+                 extend_num_tokens,
              )

          # Set fields
          self.input_ids = input_ids_tensor
          self.req_pool_indices = req_pool_indices_tensor
          self.seq_lens = seq_lens_tensor
+         self.seq_lens_cpu = seq_lens_cpu
          self.orig_seq_lens = orig_seq_lens_tensor
          self.out_cache_loc = out_cache_loc
          self.input_embeds = (
@@ -1457,7 +1520,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          )

          retracted_reqs = []
-         seq_lens_cpu = self.seq_lens.cpu().numpy()
          first_iter = True
          while first_iter or (
              not self.check_decode_mem(selected_indices=sorted_indices)
@@ -1484,37 +1546,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              idx = sorted_indices.pop()
              req = self.reqs[idx]
              retracted_reqs.append(req)
-
-             if server_args.disaggregation_mode == "decode":
-                 req.offload_kv_cache(
-                     self.req_to_token_pool, self.token_to_kv_pool_allocator
-                 )
-
-             if isinstance(self.tree_cache, ChunkCache):
-                 # ChunkCache does not have eviction
-                 token_indices = self.req_to_token_pool.req_to_token[
-                     req.req_pool_idx, : seq_lens_cpu[idx]
-                 ]
-                 self.token_to_kv_pool_allocator.free(token_indices)
-                 self.req_to_token_pool.free(req.req_pool_idx)
-             else:
-                 # TODO: apply more fine-grained retraction
-                 last_uncached_pos = (
-                     len(req.prefix_indices) // server_args.page_size
-                 ) * server_args.page_size
-                 token_indices = self.req_to_token_pool.req_to_token[
-                     req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
-                 ]
-                 self.token_to_kv_pool_allocator.free(token_indices)
-                 self.req_to_token_pool.free(req.req_pool_idx)
-
-                 # release the last node
-                 if self.is_hybrid:
-                     self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
-                 else:
-                     self.tree_cache.dec_lock_ref(req.last_node)
-
-             req.reset_for_retract()
+             self.release_req(idx, len(sorted_indices), server_args)

          if len(retracted_reqs) == 0:
              # Corner case: only one request left
@@ -1533,7 +1565,45 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          ) / total_max_new_tokens
          new_estimate_ratio = min(1.0, new_estimate_ratio)

-         return retracted_reqs, new_estimate_ratio
+         return retracted_reqs, new_estimate_ratio, []
+
+     def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
+         req = self.reqs[idx]
+         seq_lens_cpu = self.seq_lens_cpu.numpy()
+
+         if server_args.disaggregation_mode == "decode":
+             req.offload_kv_cache(
+                 self.req_to_token_pool, self.token_to_kv_pool_allocator
+             )
+         if isinstance(self.tree_cache, ChunkCache):
+             # ChunkCache does not have eviction
+             token_indices = self.req_to_token_pool.req_to_token[
+                 req.req_pool_idx, : seq_lens_cpu[idx]
+             ]
+             self.token_to_kv_pool_allocator.free(token_indices)
+             self.req_to_token_pool.free(req.req_pool_idx)
+         else:
+             # TODO: apply more fine-grained retraction
+             last_uncached_pos = (
+                 len(req.prefix_indices) // server_args.page_size
+             ) * server_args.page_size
+             token_indices = self.req_to_token_pool.req_to_token[
+                 req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
+             ]
+             self.token_to_kv_pool_allocator.free(token_indices)
+             self.req_to_token_pool.free(req.req_pool_idx)
+
+             # release the last node
+             if self.is_hybrid:
+                 self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
+             else:
+                 self.tree_cache.dec_lock_ref(req.last_node)
+
+             # NOTE(lsyin): we should use the newly evictable memory instantly.
+             num_tokens = remaing_req_count * global_config.retract_decode_steps
+             self._evict_tree_cache_if_needed(num_tokens)
+
+         req.reset_for_retract()

      def prepare_encoder_info_decode(self):
          # Reset the encoder cached status
@@ -1543,6 +1613,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          self.forward_mode = ForwardMode.IDLE
          self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
          self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+         self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
          self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
          self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
          self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
@@ -1557,7 +1628,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          self.forward_mode = ForwardMode.DECODE
          bs = len(self.reqs)

-         if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
+         if (
+             self.spec_algorithm.is_eagle()
+             or self.spec_algorithm.is_standalone()
+             or self.spec_algorithm.is_ngram()
+         ):
              # if spec decoding is used, the decode batch is prepared inside
              # `forward_batch_speculative_generation` after running draft models.
              return
@@ -1598,10 +1673,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          if self.enable_overlap:
              # Do not use in-place operations in the overlap mode
              self.seq_lens = self.seq_lens + 1
+             self.seq_lens_cpu = self.seq_lens_cpu + 1
              self.orig_seq_lens = self.orig_seq_lens + 1
          else:
              # A faster in-place version
              self.seq_lens.add_(1)
+             self.seq_lens_cpu.add_(1)
              self.orig_seq_lens.add_(1)
          self.seq_lens_sum += bs

@@ -1620,7 +1697,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              self.req_pool_indices, self.seq_lens - 2
          ]
          self.out_cache_loc = self.alloc_paged_token_slots_decode(
-             self.seq_lens, last_loc
+             self.seq_lens, self.seq_lens_cpu, last_loc
          )

          self.req_to_token_pool.write(
@@ -1666,6 +1743,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
          self.req_pool_indices = self.req_pool_indices[keep_indices_device]
          self.seq_lens = self.seq_lens[keep_indices_device]
+         self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
          self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
          self.out_cache_loc = None
          self.seq_lens_sum = self.seq_lens.sum().item()
@@ -1683,7 +1761,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

          self.sampling_info.filter_batch(keep_indices, keep_indices_device)
          if self.spec_info:
-             self.spec_info.filter_batch(keep_indices_device)
+             if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
+                 has_been_filtered = False
+             else:
+                 has_been_filtered = True
+             self.spec_info.filter_batch(
+                 new_indices=keep_indices_device,
+                 has_been_filtered=has_been_filtered,
+             )

      def merge_batch(self, other: "ScheduleBatch"):
          # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1699,6 +1784,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
              [self.req_pool_indices, other.req_pool_indices]
          )
          self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+         self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
          self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
          self.out_cache_loc = None
          self.seq_lens_sum += other.seq_lens_sum
@@ -1742,15 +1828,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
          self.sampling_info.grammars = None

          seq_lens_cpu = (
-             seq_lens_cpu_cache
-             if seq_lens_cpu_cache is not None
-             else self.seq_lens.cpu()
+             seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
          )

-         global bid
-         bid += 1
          return ModelWorkerBatch(
-             bid=bid,
              forward_mode=self.forward_mode,
              input_ids=self.input_ids,
              req_pool_indices=self.req_pool_indices,
@@ -1870,8 +1951,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):

  @dataclasses.dataclass
  class ModelWorkerBatch:
-     # The batch id
-     bid: int
      # The forward mode
      forward_mode: ForwardMode
      # The input ids
@@ -1932,7 +2011,9 @@ class ModelWorkerBatch:

      # Speculative decoding
      spec_algorithm: SpeculativeAlgorithm = None
-     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
+
+     spec_info: Optional[SpecInput] = None
+
      # If set, the output of the batch contains the hidden states of the run.
      capture_hidden_mode: CaptureHiddenMode = None
      hicache_consumer_index: int = -1
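
A change repeated across the ScheduleBatch hunks above is the new seq_lens_cpu field: a host-side int64 mirror of the device tensor seq_lens that is updated wherever seq_lens is (prepare_for_extend, prepare_for_decode, filter_batch, merge_batch) and is passed to the allocator, so paths such as release_req and get_model_worker_batch no longer need a blocking self.seq_lens.cpu() copy. The sketch below shows only the mirroring pattern; ToyBatch and its methods are made up for illustration.

from typing import List

import torch


class ToyBatch:
    """Keeps a CPU mirror of a device-side length tensor so it can be read without a sync."""

    def __init__(self, seq_lens: List[int], device: str = "cpu") -> None:
        self.device = device
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
            device, non_blocking=True
        )
        # Same values, kept on the host and updated in lockstep with seq_lens.
        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)

    def step_decode(self) -> None:
        # In-place +1 on both copies, as in ScheduleBatch.prepare_for_decode.
        self.seq_lens.add_(1)
        self.seq_lens_cpu.add_(1)

    def filter(self, keep_indices: List[int]) -> None:
        keep = torch.tensor(keep_indices, dtype=torch.int64)
        self.seq_lens = self.seq_lens[keep.to(self.device)]
        self.seq_lens_cpu = self.seq_lens_cpu[keep]

    def merge(self, other: "ToyBatch") -> None:
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
        self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])


batch = ToyBatch([5, 7, 9])
batch.step_decode()
batch.filter([0, 2])
print(batch.seq_lens_cpu.tolist())  # [6, 10], read without touching the device copy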