sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,606 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import time
6
+ from typing import TYPE_CHECKING, List
7
+
8
+ import torch
9
+ import triton
10
+ import triton.language as tl
11
+
12
+ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
13
+ from sglang.srt.environ import envs
14
+ from sglang.srt.managers.schedule_batch import Req
15
+ from sglang.srt.utils import is_cuda, is_hip
16
+
17
# Both the CUDA and ROCm builds expose the same `fast_topk` entry point from
# sgl_kernel, so the previously duplicated if/elif branches are merged into a
# single combined check.
if is_cuda() or is_hip():
    from sgl_kernel import fast_topk
21
+
22
+ if TYPE_CHECKING:
23
+ from sglang.srt.speculative.eagle_info import EagleVerifyInput
24
+
25
# Module-level logger for speculative-decoding utilities.
logger = logging.getLogger(__name__)


# Simulate acceptance length for benchmarking purposes
# Both values are read once at import time from environment-backed settings;
# a SIMULATE_ACC_LEN value < 0 disables the simulation.
SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get()  # turn off if < 0
SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()

TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly
TREE_SPEC_KERNEL_AVAILABLE = is_cuda()  # This kernel is only available for CUDA now
34
+
35
+
36
@triton.jit
def create_extend_after_decode_spec_info(
    verified_id,
    seq_lens,
    accept_lens,
    positions,
    new_verified_id,
    bs_upper: tl.constexpr,
):
    """Fill per-token positions for the extend-after-decode phase and gather
    each request's last verified token id.

    One program per batch element (``pid``).  Request ``pid``'s accepted
    tokens occupy the slice starting at ``cumsum(accept_lens[:pid])`` of the
    flattened ``positions`` buffer; each entry is the absolute sequence index
    of that accepted token.  ``bs_upper`` is a power-of-two upper bound on the
    batch size used for the prefix-sum load.
    """
    pid = tl.program_id(axis=0)
    offsets = tl.arange(0, bs_upper)
    seq_length = tl.load(seq_lens + pid)
    accept_length = tl.load(accept_lens + pid)

    # Exclusive prefix sum of accept_lens -> this request's write offset into
    # the flattened positions buffer.
    accept_len_cumsum = tl.sum(
        tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
    )
    positions_ptr = positions + accept_len_cumsum
    mask = offsets < accept_length
    tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)

    # The last accepted token of this request becomes its new verified id.
    accept_len_cumsum += accept_length - 1
    verified_id_data = tl.load(verified_id + accept_len_cumsum)
    tl.store(new_verified_id + pid, verified_id_data)
60
+
61
+
62
@triton.jit
def assign_req_to_token_pool(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Copy newly allocated cache locations into each request's row of the
    req_to_token table.

    One program per request: slots ``[start_offset[pid], end_offset[pid])`` of
    the request's row receive consecutive entries of ``out_cache_loc``; this
    request's slice of ``out_cache_loc`` begins at
    ``sum(end_offset[:pid] - start_offset[:pid])``.
    """
    BLOCK_SIZE: tl.constexpr = 32
    pid = tl.program_id(axis=0)
    kv_start = tl.load(start_offset + pid)
    kv_end = tl.load(end_offset + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len

    # Exclusive prefix sum of per-request lengths -> read offset into the
    # flattened out_cache_loc buffer.
    length_offset = tl.arange(0, bs_upper)
    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
    out_offset = tl.sum(end - start, axis=0)

    out_cache_ptr = out_cache_loc + out_offset

    save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
    load_offset = tl.arange(0, BLOCK_SIZE)

    # Copy in BLOCK_SIZE-wide chunks until the whole [kv_start, kv_end)
    # range has been written.
    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = save_offset < kv_end
        data = tl.load(out_cache_ptr + load_offset, mask=mask)
        tl.store(token_pool + save_offset, data, mask=mask)
        save_offset += BLOCK_SIZE
        load_offset += BLOCK_SIZE
95
+
96
+
97
@triton.jit
def assign_draft_cache_locs(
    req_pool_indices,
    req_to_token,
    seq_lens,
    extend_lens,
    num_new_pages_per_topk,
    out_cache_loc,
    pool_len: tl.constexpr,
    topk: tl.constexpr,
    speculative_num_steps: tl.constexpr,
    page_size: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
):
    """Write the draft tokens' cache locations into req_to_token.

    One program per request.  For the dense case (``page_size == 1`` or
    ``topk == 1``) only the plain copy (Part 1) runs.  For the paged
    multi-topk case, the last partial page of the prefix is additionally
    duplicated for every topk branch (Part 2) and the page-alignment padding
    is compacted out of ``out_cache_loc`` (Part 3).
    """
    BLOCK_SIZE: tl.constexpr = 128
    pid = tl.program_id(axis=0)

    if page_size == 1 or topk == 1:
        # Dense layout: each request contributes exactly
        # topk * speculative_num_steps new slots.
        copy_len = topk * speculative_num_steps
        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
    else:
        # Paged layout: per-request lengths vary, so the read offset is the
        # exclusive prefix sum of extend_lens.
        bs_offset = tl.arange(0, bs_upper)
        copy_len = tl.load(extend_lens + pid)
        cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
        out_cache_ptr = out_cache_loc + cum_copy_len

    # Part 1: Copy from out_cache_loc to req_to_token
    kv_start = tl.load(seq_lens + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
    for i in range(num_loop):
        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = copy_offset < copy_len
        data = tl.load(out_cache_ptr + copy_offset, mask=mask)
        tl.store(token_pool + kv_start + copy_offset, data, mask=mask)

    if page_size == 1 or topk == 1:
        return

    # Part 2: Copy the indices for the last partial page
    # (each topk branch gets its own copy of the shared partial page).
    prefix_len = tl.load(seq_lens + pid)
    last_page_len = prefix_len % page_size
    offsets = tl.arange(0, page_size)
    mask = offsets < last_page_len
    num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
    prefix_base = token_pool + prefix_len - last_page_len

    for topk_id in range(topk):
        value = tl.load(prefix_base + offsets, mask=mask)
        tl.store(
            prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
            value,
            mask=mask,
        )

    # Part 3: Remove the padding in out_cache_loc
    # NOTE(review): "iter_offest" is a pre-existing local-name typo, kept
    # byte-identical here.
    iter_offest = tl.arange(0, iter_upper)
    for topk_id in range(topk):
        indices = tl.load(
            prefix_base
            + topk_id * num_new_pages_per_topk_ * page_size
            + last_page_len
            + iter_offest,
            mask=iter_offest < speculative_num_steps,
        )
        tl.store(
            out_cache_loc
            + pid * topk * speculative_num_steps
            + topk_id * speculative_num_steps
            + iter_offest,
            indices,
            mask=iter_offest < speculative_num_steps,
        )
171
+
172
+
173
@triton.jit
def generate_draft_decode_kv_indices(
    req_pool_indices,
    req_to_token,
    paged_kernel_lens,
    kv_indices,
    kv_indptr,
    positions,
    pool_len: tl.constexpr,
    kv_indices_stride: tl.constexpr,
    kv_indptr_stride: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
    num_tokens_upper: tl.constexpr,
    page_size: tl.constexpr,
):
    """Build per-draft-step KV indices and indptr tables for draft decoding.

    Launch grid appears to be (num_draft_steps, num_seqs, topk); each program
    writes the KV index list for one (step, request, topk-branch)
    combination into the step's slab of ``kv_indices``/``kv_indptr``
    (selected via the two stride arguments).
    """
    BLOCK_SIZE: tl.constexpr = 128
    iters = tl.program_id(axis=0)
    bid = tl.program_id(axis=1)
    topk_id = tl.program_id(axis=2)

    num_steps = tl.num_programs(axis=0)
    num_seqs = tl.num_programs(axis=1)
    topk = tl.num_programs(axis=2)

    # Each draft step owns its own kv_indices / kv_indptr slab.
    kv_indices += kv_indices_stride * iters
    kv_indptr += kv_indptr_stride * iters
    iters += 1

    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
    seq_len = tl.load(paged_kernel_lens + bid)
    cum_seq_len = tl.sum(seq_lens)

    # Update kv_indices
    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
    kv_ptr = kv_indices + kv_offset
    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len

    # Copy the request's existing prefix tokens in BLOCK_SIZE chunks.
    kv_offset = tl.arange(0, BLOCK_SIZE)
    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = kv_offset < seq_len
        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
        tl.store(kv_ptr + kv_offset, data, mask=mask)
        kv_offset += BLOCK_SIZE

    # Append the draft tokens produced so far (`iters` of them) for this
    # topk branch.
    extend_offset = tl.arange(0, iter_upper)
    if page_size == 1 or topk == 1:
        extend_data = tl.load(
            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
            mask=extend_offset < iters,
        )
    else:
        # Paged layout: each topk branch's draft tokens start on their own
        # page group right after the last full page of the prefix (mirrors
        # the layout written by assign_draft_cache_locs).
        prefix_len = seq_len
        last_page_len = prefix_len % page_size
        num_new_pages_per_topk = (
            last_page_len + num_steps + page_size - 1
        ) // page_size
        prefix_base = seq_len // page_size * page_size
        start = (
            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
        )
        extend_data = tl.load(
            token_pool_ptr + start + extend_offset,
            mask=extend_offset < iters,
        )

    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)

    # Update kv_indptr
    bs_offset = tl.arange(0, num_tokens_upper)

    # Program (bid=0, topk_id=0) doubles as the writer of the final (total)
    # indptr entry at index num_seqs * topk.
    zid = bid * topk + topk_id
    if zid == 0:
        zid = num_seqs * topk
    positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
    base = tl.sum(positions)
    tl.store(kv_indptr + zid, base + zid * iters)
252
+
253
+
254
@triton.jit
def align_evict_mask_to_page_size(
    seq_lens,
    evict_mask,
    page_size: tl.constexpr,
    num_draft_tokens: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Clear eviction flags inside the last partially-kept page so eviction
    only happens at page granularity.

    One program handles one request's row of ``evict_mask``
    (``num_draft_tokens`` boolean entries).  ``BLOCK_SIZE`` must be >=
    ``num_draft_tokens`` so one masked load covers the whole row.
    """
    t_range = tl.arange(0, BLOCK_SIZE)

    bid = tl.program_id(axis=0)
    seq_len = tl.load(seq_lens + bid)
    io_mask = t_range < num_draft_tokens
    mask_row = tl.load(
        evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
    )

    # num_false = number of kept (non-evicted) draft tokens in this row.
    num_trues = tl.sum(mask_row)
    num_false = num_draft_tokens - num_trues

    # `start` is the draft-token index where the last kept page begins
    # (relative to the draft region); flags in [start, start + page_size)
    # are forced to "keep".
    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
        tl.store(evict_mask + bid * num_draft_tokens + i, False)
277
+
278
+
279
@triton.jit
def get_target_cache_loc(
    tgt_cache_loc,
    to_free_slots,
    accept_length,
    to_free_num_slots,
    out_cache_loc,
    num_verify_tokens: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Split each request's verify-phase cache slots into the accepted part
    (compacted into ``tgt_cache_loc``) and the releasable tail (compacted
    into ``to_free_slots``).

    One program per request; both outputs are written contiguously using
    exclusive prefix sums over the earlier requests.
    """
    bid = tl.program_id(axis=0)
    offset = tl.arange(0, num_verify_tokens_upper)
    bs_offset = tl.arange(0, bs_upper)

    # write the first part to tgt_cache_loc
    # (accept_length[bid] + 1 slots: accepted tokens plus the bonus token)
    accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
    copy_len = tl.load(accept_length + bid) + 1
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
    )
    tl.store(
        tgt_cache_loc + tgt_cache_loc_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )

    # write the second part to to_free_num_pages
    # (the last to_free_num_slots[bid] entries of this request's row)
    to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
    to_free_slots_start = tl.sum(to_free_num_slots_all)

    copy_len = to_free_num_slots_cur
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
        mask=offset < copy_len,
    )
    tl.store(
        to_free_slots + to_free_slots_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )
323
+
324
+
325
+ @torch.compile(dynamic=True)
326
+ def get_src_tgt_cache_loc(
327
+ seq_lens: torch.Tensor,
328
+ out_cache_loc: torch.Tensor,
329
+ accept_index: torch.Tensor,
330
+ accept_length: torch.Tensor,
331
+ draft_token_num: int,
332
+ page_size: int,
333
+ ):
334
+ src_cache_loc = out_cache_loc[accept_index]
335
+ tgt_cache_loc = torch.empty_like(src_cache_loc)
336
+ extended_len = seq_lens + draft_token_num
337
+ keep_len = torch.minimum(
338
+ (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
339
+ extended_len,
340
+ )
341
+ to_free_num_slots = extended_len - keep_len
342
+ return src_cache_loc, tgt_cache_loc, to_free_num_slots
343
+
344
+
345
@triton.jit
def filter_finished_cache_loc_kernel(
    out_cache_loc,
    tgt_cache_loc,
    accept_length,
    accept_length_filter,
    bs_upper: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
):
    """Compact accepted cache locations of still-running requests from
    ``tgt_cache_loc`` back into ``out_cache_loc``.

    Finished requests have ``accept_length_filter == 0`` and therefore copy
    nothing.  One program per request.
    """
    bid = tl.program_id(0)
    bs_offset = tl.arange(0, bs_upper)

    # Read offset: prefix sum of (accept_length + 1) over earlier requests.
    accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    old_start = tl.sum(accept_length_all) + bid

    # Write offset: prefix sum of the filtered lengths over earlier requests.
    accept_length_filter_all = tl.load(
        accept_length_filter + bs_offset, mask=bs_offset < bid
    )
    new_start = tl.sum(accept_length_filter_all)

    copy_len = tl.load(accept_length_filter + bid)
    copy_offset = tl.arange(0, num_verify_tokens_upper)
    value = tl.load(
        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
    )
    tl.store(
        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
    )
373
+
374
+
375
+ @torch.compile(dynamic=True)
376
+ def create_accept_length_filter(
377
+ accept_length: torch.Tensor,
378
+ unfinished_index_device: torch.Tensor,
379
+ seq_lens: torch.Tensor,
380
+ ):
381
+ accept_length_filter = torch.zeros_like(accept_length)
382
+ accept_length_filter[unfinished_index_device] = (
383
+ accept_length[unfinished_index_device] + 1
384
+ )
385
+ seq_lens.add_(accept_length + 1)
386
+ return accept_length_filter
387
+
388
+
389
@torch.compile(dynamic=True)
def select_top_k_tokens(
    i: int,
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: torch.Tensor,
    scores: torch.Tensor,
    topk: int,
):
    """Select the top-k draft tokens for the next speculative step.

    Step 0 (right after extend) fans every request out into ``topk``
    candidates; later steps expand each candidate by ``topk`` and keep the
    best ``topk`` of the topk*topk products.

    Returns ``(input_ids, hidden_states, scores, tree_info)``.
    """
    if i != 0:
        # Later decode steps: combine parent scores with child probabilities.
        # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)
        expand_scores = torch.mul(scores.unsqueeze(2), topk_p.reshape(-1, topk, topk))
        topk_cs_p, topk_cs_index = fast_topk(
            expand_scores.flatten(start_dim=1), topk, dim=-1
        )  # (b, topk)
        scores = topk_cs_p  # shape: (b, topk)

        topk_index = topk_index.reshape(-1, topk**2)
        input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()

        if hidden_states.shape[0] > 0:
            # Map each kept candidate back to the parent row it expanded from.
            selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
                0, hidden_states.shape[0], step=topk, device="cuda"
            ).repeat_interleave(topk)
            hidden_states = hidden_states[selected_input_index, :]

        tree_info = (
            expand_scores,  # shape: (b, topk, topk)
            topk_index,  # shape: (b, topk * topk)
            topk_cs_index + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
        )
    else:
        # The first step after extend: every request branches into topk drafts.
        input_ids = topk_index.flatten()
        hidden_states = hidden_states.repeat_interleave(topk, dim=0)
        scores = topk_p  # shape: (b, topk)

        parent_pointers = (
            torch.arange(-1, topk, dtype=torch.long, device="cuda")
            .unsqueeze(0)
            .repeat(topk_p.shape[0], 1)
        )  # shape: (b, topk + 1)
        tree_info = (
            topk_p.unsqueeze(1),  # shape: (b, 1, topk)
            topk_index,  # shape: (b, topk)
            parent_pointers,
        )

    return input_ids, hidden_states, scores, tree_info
437
+
438
+
439
def _generate_simulated_accept_index(
    accept_index,
    predict,
    accept_length,
    bs,
    spec_steps,
    simulate_acc_len: float = SIMULATE_ACC_LEN,
    simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
    """Overwrite real verification results with a simulated acceptance
    length (benchmarking only).

    Mutates ``accept_length`` (filled with simulate_acc_len - 1) and
    ``predict`` (filled with a placeholder token id) in place, and returns a
    new (bs, spec_steps + 1) accept-index tensor whose first
    ``simulate_acc_len`` columns are consecutive positions starting from each
    request's first accepted index; the rest are -1.

    Raises ValueError for an unknown ``simulate_acc_method``.
    """
    assert simulate_acc_len > 0.0

    if simulate_acc_method == "multinomial":
        simulated_values = torch.normal(
            mean=simulate_acc_len,
            std=1.0,
            size=(1,),
            device="cpu",
        )
        # clamp simulated values to be between 1 and self.spec_steps
        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
        simulate_acc_len = int(simulated_values.round().item())
    elif simulate_acc_method == "match-expected":
        # multinomial sampling does not match the expected length
        # we keep it for the sake of compatibility of existing tests
        # but it's better to use "match-expected" for the cases that need to
        # match the expected length, One caveat is that this will only sample
        # either round down or round up of the expected length
        simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
        lower = int(simulate_acc_len // 1)
        upper = lower + 1 if lower < spec_steps + 1 else lower
        if lower == upper:
            simulate_acc_len = lower
        else:
            # Sample lower/upper with weights chosen so the expectation
            # matches the requested (fractional) length.
            weight_upper = simulate_acc_len - lower
            weight_lower = 1.0 - weight_upper
            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
            sampled_index = torch.multinomial(probs, num_samples=1)
            simulate_acc_len = lower if sampled_index == 0 else upper
    else:
        # BUGFIX: report the actual argument, not the module-level default —
        # previously this interpolated the global SIMULATE_ACC_METHOD, which
        # misreported explicitly-passed invalid methods.
        raise ValueError(f"Invalid simulate_acc_method: {simulate_acc_method}")

    accept_index_first_col = accept_index[:, 0].view(-1, 1)
    sim_accept_index = torch.full(
        (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
    )
    sim_accept_index[:, :simulate_acc_len] = accept_index_first_col + torch.arange(
        simulate_acc_len, device=accept_index.device
    )
    accept_length.fill_(simulate_acc_len - 1)
    predict.fill_(100)  # some legit token id
    return sim_accept_index
490
+
491
+
492
def traverse_tree(
    retrieve_next_token: torch.Tensor,
    retrieve_next_sibling: torch.Tensor,
    draft_tokens: torch.Tensor,
    grammar: BaseGrammarObject,
    allocate_token_bitmask: torch.Tensor,
):
    """
    Traverse the tree constructed by the draft model to generate the logits mask.

    The draft tree for a single request is encoded as two flat index arrays:
    ``retrieve_next_token[i]`` is the index of node i's first child (-1 if none)
    and ``retrieve_next_sibling[i]`` is the index of node i's next sibling
    (-1 if none). ``draft_tokens[i]`` is the token id at node i; node 0 is the
    token already produced by the target model. For every node whose token the
    grammar accepts, the corresponding row of ``allocate_token_bitmask`` is
    filled with the grammar's vocab mask; rows of rejected nodes stay zero.
    The grammar state is mutated during traversal but restored via rollback,
    so it ends where it started.
    """
    assert (
        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
    )

    # Start from a clean slate: all tokens masked out until proven accepted.
    allocate_token_bitmask.fill_(0)

    def dfs(
        curr: int,
        retrieve_next_token: torch.Tensor,
        retrieve_next_sibling: torch.Tensor,
        parent_pos: int,
    ):
        # parent_pos is the node whose bitmask decides whether `curr` is legal;
        # -1 only for the root, which is accepted unconditionally.
        if curr == 0:
            # the first token generated by the target model, and thus it is always
            # accepted from the previous iteration
            accepted = True
        else:
            parent_bitmask = allocate_token_bitmask[parent_pos]
            curr_token_id = draft_tokens[curr]
            # 32 boolean bitmask values are packed into 32-bit integers
            accepted = (
                parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
            ) != 0

        if accepted:
            if curr != 0:
                # Accept the current token
                grammar.accept_token(draft_tokens[curr])
            if not grammar.is_terminated():
                # Generate the bitmask for the current token
                grammar.fill_vocab_mask(allocate_token_bitmask, curr)
                if retrieve_next_token[curr] != -1:
                    # Visit the child node (children see `curr` as their parent)
                    dfs(
                        retrieve_next_token[curr],
                        retrieve_next_token,
                        retrieve_next_sibling,
                        curr,
                    )

            if curr != 0:
                # Rollback the current token so siblings are evaluated against
                # the parent's grammar state, not this node's.
                grammar.rollback(1)

        if retrieve_next_sibling[curr] != -1:
            # Visit the sibling node (siblings share the same parent_pos)
            dfs(
                retrieve_next_sibling[curr],
                retrieve_next_token,
                retrieve_next_sibling,
                parent_pos,
            )

    dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
556
+
557
+
558
def generate_token_bitmask(
    reqs: List[Req],
    verify_input: EagleVerifyInput,
    retrieve_next_token_cpu: torch.Tensor,
    retrieve_next_sibling_cpu: torch.Tensor,
    draft_tokens_cpu: torch.Tensor,
    vocab_size: int,
):
    """
    Generate the logit mask for structured output.
    Draft model's token can be either valid or invalid with respect to the grammar.
    We need to perform DFS to
    1. figure out which tokens are accepted by the grammar.
    2. if so, what is the corresponding logit mask.
    """

    num_draft_tokens = draft_tokens_cpu.shape[-1]
    assert len(reqs) == retrieve_next_token_cpu.shape[0]

    allocate_token_bitmask = None
    grammar = None
    for req_idx, req in enumerate(reqs):
        if req.grammar is None:
            # No structured-output constraint on this request.
            continue

        if allocate_token_bitmask is None:
            # Lazily allocate one shared bitmask covering every draft token
            # of every request in the batch.
            allocate_token_bitmask = req.grammar.allocate_vocab_mask(
                vocab_size=vocab_size,
                batch_size=draft_tokens_cpu.numel(),
                device="cpu",
            )
        grammar = req.grammar

        row_start = req_idx * num_draft_tokens
        tic = time.perf_counter()
        traverse_tree(
            retrieve_next_token_cpu[req_idx],
            retrieve_next_sibling_cpu[req_idx],
            draft_tokens_cpu[req_idx],
            req.grammar,
            allocate_token_bitmask[row_start : row_start + num_draft_tokens],
        )
        tree_traverse_time = time.perf_counter() - tic
        if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
            # Slow traversals usually indicate a pathological grammar.
            logger.warning(
                f"Bit mask generation took {tree_traverse_time} seconds with "
                f"grammar: {req.grammar}"
            )

    # Last grammar seen wins (matches the original behavior).
    verify_input.grammar = grammar
    return allocate_token_bitmask
@@ -1,8 +1,6 @@
1
1
  import logging
2
- import threading
3
- import time
4
2
  from abc import ABC
5
- from contextlib import contextmanager, nullcontext
3
+ from contextlib import contextmanager
6
4
 
7
5
  try:
8
6
  import torch_memory_saver
@@ -40,7 +38,7 @@ class TorchMemorySaverAdapter(ABC):
40
38
  def configure_subprocess(self):
41
39
  raise NotImplementedError
42
40
 
43
- def region(self, tag: str):
41
+ def region(self, tag: str, enable_cpu_backup: bool = False):
44
42
  raise NotImplementedError
45
43
 
46
44
  def pause(self, tag: str):
@@ -60,8 +58,8 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
60
58
  def configure_subprocess(self):
61
59
  return torch_memory_saver.configure_subprocess()
62
60
 
63
- def region(self, tag: str):
64
- return _memory_saver.region(tag=tag)
61
+ def region(self, tag: str, enable_cpu_backup: bool = False):
62
+ return _memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup)
65
63
 
66
64
  def pause(self, tag: str):
67
65
  return _memory_saver.pause(tag=tag)
@@ -80,7 +78,7 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
80
78
  yield
81
79
 
82
80
  @contextmanager
83
- def region(self, tag: str):
81
+ def region(self, tag: str, enable_cpu_backup: bool = False):
84
82
  yield
85
83
 
86
84
  def pause(self, tag: str):