sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,245 @@
1
+ import logging
2
+ from typing import List, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
7
+
8
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
9
+ from sglang.srt.managers.tp_worker import TpModelWorker
10
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput, ForwardMode
11
+ from sglang.srt.server_args import ServerArgs
12
+ from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
13
+ from sglang.srt.speculative.ngram_utils import NgramVerifyInput
14
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
# Build a full [draft_token_num, seq_len] attention mask per request during
# verification.  A QLEN-only mask is faster (~8%, scaling with batch size) but
# requires matching changes in flashinfer — see the NOTE in
# NGRAMWorker._prepare_for_speculative_decoding.
USE_FULL_MASK = True
19
+
20
+
21
class NGRAMWorker:
    """Speculative-decoding worker that drafts tokens from an n-gram cache.

    Instead of running a separate draft model, draft tokens are looked up in
    an `NgramCache` built from token sequences seen so far.  The drafts for a
    batch form token trees; the target model verifies all draft tokens in a
    single TARGET_VERIFY forward pass and the accepted prefix is kept.

    All per-batch GPU tensors are preallocated for the maximum batch size in
    `_init_preallocated_tensors` and sliced per batch size, so the hot path
    performs no device allocations.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        # NOTE(review): dp_rank, moe_ep_rank and nccl_port are accepted
        # (presumably to match the other speculative workers' signatures)
        # but are not used in this class.
        self.target_worker = target_worker
        self.model_runner = target_worker.model_runner
        self.tp_rank = tp_rank
        self.page_size = server_args.page_size
        # Number of draft tokens proposed/verified per request per step.
        self.draft_token_num: int = server_args.speculative_num_draft_tokens
        # Length of the token suffix inserted into the cache after each verify.
        self.branch_length: int = server_args.speculative_ngram_branch_length
        # Longest sequence suffix used as the lookup key against the cache.
        self.max_match_window_size: int = (
            server_args.speculative_ngram_max_match_window_size
        )

        self.max_batch_size = target_worker.max_running_requests
        self.device = f"cuda:{gpu_id}" if gpu_id >= 0 else "cuda"

        self._init_preallocated_tensors()

        self.ngram_cache = NgramCache(
            min_match_window_size=server_args.speculative_ngram_min_match_window_size,
            max_match_window_size=server_args.speculative_ngram_max_match_window_size,
            min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
            max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
            capacity=server_args.speculative_ngram_capacity,
            branch_length=server_args.speculative_ngram_branch_length,
            draft_token_num=server_args.speculative_num_draft_tokens,
        )

    def clear_cache_pool(self) -> None:
        """Reset the underlying n-gram cache, dropping all stored sequences."""
        self.ngram_cache.reset()

    def _efficient_concat_last_n(
        self, seq1: List[int], seq2: List[int], n: int
    ) -> List[int]:
        """Return the last `n` elements of `seq1 + seq2` without building the
        full concatenation (only the needed tail of `seq1` is copied)."""
        seq2_len = len(seq2)
        if seq2_len >= n:
            return seq2[-n:]

        need_from_seq1 = n - seq2_len
        return seq1[-need_from_seq1:] + seq2

    def _init_preallocated_tensors(self) -> None:
        """Preallocate device tensors for the largest possible batch and
        per-batch-size slice views over them.

        The `*_batch` lists are indexed by batch size `bs` (0..max_batch_size)
        and hold views into the flat tensors, so `_prepare_for_speculative_decoding`
        can pick the right-sized view with a single list lookup.
        """
        max_total_drafts = self.max_batch_size * self.draft_token_num
        # One [draft_token_num x draft_token_num] tree mask per request.
        max_total_mask_size = (
            self.max_batch_size * self.draft_token_num * self.draft_token_num
        )

        self.draft_tokens = torch.empty(
            (max_total_drafts,), dtype=torch.int64, device=self.device
        )
        self.retrieve_indexes = torch.empty(
            (self.max_batch_size, self.draft_token_num),
            dtype=torch.int64,
            device=self.device,
        )
        # NOTE: "retrive" spelling matches the NgramVerifyInput field names.
        self.retrive_next_token = torch.empty(
            (self.max_batch_size, self.draft_token_num),
            dtype=torch.int64,
            device=self.device,
        )
        self.retrive_next_sibling = torch.empty(
            (self.max_batch_size, self.draft_token_num),
            dtype=torch.int64,
            device=self.device,
        )
        self.positions = torch.empty(
            (max_total_drafts,), dtype=torch.int64, device=self.device
        )
        self.tree_mask = torch.empty(
            (max_total_mask_size,), dtype=torch.bool, device=self.device
        )

        self.draft_tokens_batch = []
        self.tree_mask_batch = []
        self.retrieve_indexes_batch = []
        self.retrive_next_token_batch = []
        self.retrive_next_sibling_batch = []
        self.positions_batch = []

        # Index 0 holds empty views so lists can be indexed directly by bs.
        for bs in range(0, self.max_batch_size + 1):
            self.retrieve_indexes_batch.append(self.retrieve_indexes[:bs, :])
            self.retrive_next_token_batch.append(self.retrive_next_token[:bs, :])
            self.retrive_next_sibling_batch.append(self.retrive_next_sibling[:bs, :])
            self.positions_batch.append(self.positions[: bs * self.draft_token_num])
            self.draft_tokens_batch.append(
                self.draft_tokens[: bs * self.draft_token_num]
            )
            self.tree_mask_batch.append(
                self.tree_mask[: bs * self.draft_token_num * self.draft_token_num]
            )

    def _prepare_draft_tokens(
        self, batch: ScheduleBatch
    ) -> tuple[np.ndarray, np.ndarray]:
        """Look up draft tokens for every request in `batch`.

        Returns a pair of numpy arrays `(req_drafts, mask)`:
        `req_drafts` holds `draft_token_num` drafts per request and `mask`
        the corresponding flattened tree masks (layout consumed by
        `reconstruct_indices_from_tree_mask` — assumed
        bs * draft_token_num * draft_token_num; TODO confirm against
        NgramCache.batch_get).
        """
        bs = batch.batch_size()

        # Make sure pending async cache updates are visible before lookup.
        self.ngram_cache.synchronize()
        batch_tokens = []
        for req in batch.reqs:
            check_token = self._efficient_concat_last_n(
                req.origin_input_ids, req.output_ids, self.max_match_window_size
            )
            batch_tokens.append(check_token)
        req_drafts, mask = self.ngram_cache.batch_get(batch_tokens)
        total_draft_token_num = len(req_drafts)

        # Check if speculative decoding is needed; here we always enforce it
        assert (
            total_draft_token_num == bs * self.draft_token_num
        ), f"{total_draft_token_num=}, {bs=}, {self.draft_token_num=}"
        return req_drafts, mask

    def _prepare_for_speculative_decoding(self, batch: ScheduleBatch) -> None:
        """Attach draft tokens and verification metadata to `batch`, switching
        it to TARGET_VERIFY mode.  Extend (prefill) batches are left untouched.
        """
        if batch.forward_mode.is_extend():
            return

        bs = batch.batch_size()

        # Right-sized views into the preallocated tensors (no allocation).
        retrive_index = self.retrieve_indexes_batch[bs]
        retrive_next_token = self.retrive_next_token_batch[bs]
        retrive_next_sibling = self.retrive_next_sibling_batch[bs]
        positions = self.positions_batch[bs]
        tree_mask = self.tree_mask_batch[bs]
        draft_tokens = self.draft_tokens_batch[bs]

        req_drafts, mask = self._prepare_draft_tokens(batch)
        # Async host->device copies of the numpy lookup results.
        tree_mask.copy_(torch.from_numpy(mask), non_blocking=True)
        draft_tokens.copy_(torch.from_numpy(req_drafts), non_blocking=True)

        # Kernel derives tree-linkage metadata from the mask, writing into the
        # four output tensors in place.
        reconstruct_indices_from_tree_mask(
            tree_mask,
            batch.seq_lens,
            positions,  # mutable
            retrive_index,  # mutable
            retrive_next_token,  # mutable
            retrive_next_sibling,  # mutable
            bs,
            self.draft_token_num,
        )

        # NOTE: QLEN_MASK is faster than FULL_MASK, but requires corresponding changes in flashinfer.
        # Testing shows about 8% performance improvement (the effect is roughly proportional to batch size).
        if USE_FULL_MASK:
            # Rebuild the mask as a full [draft_token_num, seq_len] mask per
            # request: all-ones over the history (seq_len - 1 columns) followed
            # by the tree-structured draft part.
            tree_mask = []
            mask = mask.reshape(
                batch.batch_size(), self.draft_token_num, self.draft_token_num
            )
            for i, req in enumerate(batch.reqs):
                seq_len = len(req.origin_input_ids) + len(req.output_ids)
                req_mask = torch.ones((self.draft_token_num, seq_len - 1)).cuda()
                req_mask = torch.cat(
                    (req_mask, torch.from_numpy(mask[i]).cuda()), dim=1
                ).to(torch.bool)
                tree_mask.append(req_mask.flatten())
            tree_mask = torch.cat(tree_mask, dim=0)

        batch.spec_algorithm = SpeculativeAlgorithm.NGRAM
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.spec_info = NgramVerifyInput(
            draft_tokens,
            tree_mask,
            positions,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            self.draft_token_num,
        )
        batch.spec_info.prepare_for_verify(batch, self.page_size)

    def _update_ngram_cache(self, batch: ScheduleBatch) -> None:
        """Insert each request's trailing `branch_length` tokens into the cache."""
        batch_tokens = []
        for req in batch.reqs:
            # FIXME: Whether to insert 'extend' into the cache or not, after testing,
            # there is not much difference, so we will not insert it for now.
            # if batch.forward_mode.is_extend():
            #     put_ids = req.origin_input_ids + req.output_ids
            # else:
            put_ids = self._efficient_concat_last_n(
                req.origin_input_ids, req.output_ids, self.branch_length
            )
            batch_tokens.append(put_ids)
        self.ngram_cache.batch_put(batch_tokens)

    def forward_batch_generation(self, batch: ScheduleBatch) -> ForwardBatchOutput:
        """Run one generation step for `batch` through the target model.

        Decode batches are converted to TARGET_VERIFY (drafts attached), the
        target model scores the draft trees, and `verify` selects the accepted
        tokens; the batch is then restored to DECODE mode.  Extend batches fall
        through to a plain target-model forward pass.
        """
        self._prepare_for_speculative_decoding(batch)
        model_worker_batch = batch.get_model_worker_batch()
        num_accepted_tokens = 0

        if model_worker_batch.forward_mode.is_target_verify():
            forward_batch_output = self.target_worker.forward_batch_generation(
                model_worker_batch, is_verify=True
            )
            logits_output, can_run_cuda_graph = (
                forward_batch_output.logits_output,
                forward_batch_output.can_run_cuda_graph,
            )
            verify_input = model_worker_batch.spec_info
            # verify() picks the accepted draft prefix and the next token ids.
            logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
                batch, logits_output, self.page_size
            )
            # Feed the newly accepted tokens back into the n-gram cache.
            self._update_ngram_cache(batch)
            batch.forward_mode = ForwardMode.DECODE

        else:
            forward_batch_output = self.target_worker.forward_batch_generation(
                model_worker_batch
            )
            logits_output, next_token_ids, can_run_cuda_graph = (
                forward_batch_output.logits_output,
                forward_batch_output.next_token_ids,
                forward_batch_output.can_run_cuda_graph,
            )

        return ForwardBatchOutput(
            logits_output=logits_output,
            next_token_ids=next_token_ids,
            num_accepted_tokens=num_accepted_tokens,
            can_run_cuda_graph=can_run_cuda_graph,
        )
@@ -1,4 +1,8 @@
1
+ from abc import ABC, abstractmethod
1
2
  from enum import IntEnum, auto
3
+ from typing import List, Tuple
4
+
5
+ from sglang.srt.managers.schedule_batch import ModelWorkerBatch
2
6
 
3
7
 
4
8
  class SpeculativeAlgorithm(IntEnum):
@@ -6,6 +10,7 @@ class SpeculativeAlgorithm(IntEnum):
6
10
  EAGLE = auto()
7
11
  EAGLE3 = auto()
8
12
  STANDALONE = auto()
13
+ NGRAM = auto()
9
14
 
10
15
  def is_none(self):
11
16
  return self == SpeculativeAlgorithm.NONE
@@ -19,14 +24,56 @@ class SpeculativeAlgorithm(IntEnum):
19
24
  def is_standalone(self):
20
25
  return self == SpeculativeAlgorithm.STANDALONE
21
26
 
27
+ def is_ngram(self):
28
+ return self == SpeculativeAlgorithm.NGRAM
29
+
22
30
  @staticmethod
23
31
  def from_string(name: str):
24
32
  name_map = {
25
33
  "EAGLE": SpeculativeAlgorithm.EAGLE,
26
34
  "EAGLE3": SpeculativeAlgorithm.EAGLE3,
27
35
  "STANDALONE": SpeculativeAlgorithm.STANDALONE,
36
+ "NGRAM": SpeculativeAlgorithm.NGRAM,
28
37
  None: SpeculativeAlgorithm.NONE,
29
38
  }
30
39
  if name is not None:
31
40
  name = name.upper()
32
41
  return name_map[name]
42
+
43
+
44
class SpecInputType(IntEnum):
    # NOTE: distinguishes the SpecInput variants produced by the different
    # speculative algorithms, so attention backends can assert which one they
    # received.  If all algorithms ever share a single draft/verify data
    # structure, this enum can be simplified away.
    EAGLE_DRAFT = auto()
    EAGLE_VERIFY = auto()
    NGRAM_VERIFY = auto()


class SpecInput(ABC):
    """Base class for speculative-decoding inputs attached to a forward batch."""

    def __init__(self, spec_input_type: SpecInputType):
        self.spec_input_type = spec_input_type

    def is_draft_input(self) -> bool:
        # FIXME: remove this function which is only used for assertion
        # or use another variable name like `draft_input` to substitute `spec_info`
        return self.spec_input_type == SpecInputType.EAGLE_DRAFT

    def is_verify_input(self) -> bool:
        verify_types = (SpecInputType.EAGLE_VERIFY, SpecInputType.NGRAM_VERIFY)
        return self.spec_input_type in verify_types

    @abstractmethod
    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        """Return the (num_tokens, num_tokens_for_logprob) scaling factors."""

    def get_spec_adjusted_global_num_tokens(
        self, forward_batch: "ModelWorkerBatch"
    ) -> Tuple[List[int], List[int]]:
        """Scale the batch's global token counts by the algorithm's coefficients."""
        token_coef, logprob_coef = self.get_spec_adjust_token_coefficient()
        scaled_tokens = [n * token_coef for n in forward_batch.global_num_tokens]
        scaled_logprob_tokens = [
            n * logprob_coef for n in forward_batch.global_num_tokens_for_logprob
        ]
        return scaled_tokens, scaled_logprob_tokens