sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,110 @@
1
+ #pragma once
2
+
3
+ #include <cstddef>
4
+ #include <cstdint>
5
+ #include <functional>
6
+ #include <list>
7
+ #include <mutex>
8
+ #include <set>
9
+ #include <sstream>
10
+ #include <thread>
11
+ #include <tuple>
12
+ #include <unordered_map>
13
+ #include <vector>
14
+
15
+ #include "param.h"
16
+ #include "queue.h"
17
+
18
+ namespace ngram {
19
+
20
// One node of the token trie used by Ngram.
//
// Scalar members carry default member initializers so that a
// default-initialized node (e.g. `TrieNode n;`) never holds indeterminate
// values; value-initialization (`TrieNode()`) yields the same state as before.
struct TrieNode {
  // Child nodes keyed by token id.
  std::unordered_map<int32_t, TrieNode*> child;
  // Iterators into LRU lists of TrieNode*.
  // NOTE(review): presumably positions in Ngram::global_lru_ and in the
  // parent's `lru` list respectively -- confirm against ngram.cpp.
  std::list<TrieNode*>::const_iterator global_lru_pos;
  std::list<TrieNode*>::const_iterator parent_lru_pos;
  int32_t token = 0;
  TrieNode* parent = nullptr;
  std::list<TrieNode*> lru;
  int32_t freq = 0;

  // Strict weak ordering for `sorted_children`: higher `freq` sorts first,
  // ties are broken by lower `token`, then by pointer value.
  struct CompareByFreq {
    bool operator()(TrieNode* a, TrieNode* b) const {
      return std::tie(b->freq, a->token, a) < std::tie(a->freq, b->token, b);
    }
  };
  // Child pointers ordered by CompareByFreq. Because the ordering depends on
  // `freq`, an element must be removed and re-inserted when its `freq` changes.
  std::multiset<TrieNode*, CompareByFreq> sorted_children;
};
36
+
37
+ class Ngram {
38
+ std::vector<TrieNode> nodes_;
39
+ std::vector<TrieNode*> node_pool_;
40
+ size_t free_node_count_;
41
+ std::list<TrieNode*> global_lru_;
42
+ TrieNode* root_;
43
+ std::vector<TrieNode*> path_;
44
+ Param param_;
45
+
46
+ std::vector<std::pair<TrieNode*, int32_t>> match(const std::vector<int32_t>& tokens, size_t batch_size) const;
47
+
48
+ void squeeze(size_t count);
49
+
50
+ TrieNode* getNode() {
51
+ auto node = node_pool_[--free_node_count_];
52
+ node->~TrieNode();
53
+ new (node) TrieNode();
54
+ return node;
55
+ }
56
+
57
+ mutable std::mutex mutex_;
58
+ bool quit_flag_;
59
+ utils::Queue<std::vector<int32_t>> insert_queue_;
60
+ std::thread insert_worker_;
61
+ std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> match_tmp_data_;
62
+
63
+ public:
64
+ Ngram(size_t capacity, const Param& param);
65
+ Ngram() = default;
66
+ ~Ngram();
67
+
68
+ static Ngram& instance() {
69
+ static Ngram instance;
70
+ return instance;
71
+ }
72
+
73
+ void synchronize() const;
74
+
75
+ void asyncInsert(std::vector<std::vector<int32_t>>&& tokens);
76
+
77
+ struct Result {
78
+ std::vector<int32_t> token;
79
+ std::vector<uint8_t> mask;
80
+
81
+ void truncate(size_t n);
82
+ };
83
+
84
+ Result batchMatch(const std::vector<std::vector<int32_t>>& tokens) const;
85
+
86
+ void reset() {
87
+ std::unique_lock<std::mutex> lock(mutex_);
88
+
89
+ global_lru_.clear();
90
+ path_.clear();
91
+ node_pool_.clear();
92
+ for (auto& node : nodes_) {
93
+ node_pool_.emplace_back(&node);
94
+ }
95
+ free_node_count_ = node_pool_.size();
96
+ root_ = getNode();
97
+ }
98
+
99
+ const Param& param() const {
100
+ return param_;
101
+ }
102
+
103
+ private:
104
+ Result matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const;
105
+ Result matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const;
106
+
107
+ void insert();
108
+ };
109
+
110
+ } // namespace ngram
@@ -0,0 +1,138 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import logging
4
+ import os
5
+ from typing import List, Tuple
6
+
7
+ import numpy as np
8
+ from torch.utils.cpp_extension import load
9
+
10
logger = logging.getLogger(__name__)

# Directory containing this module; the C++ sources live next to it.
_abs_path = os.path.dirname(os.path.abspath(__file__))

# JIT-build (or load a previously built copy of) the C++ extension at import time.
_cpp_sources = [
    f"{_abs_path}/{source_name}"
    for source_name in ("ngram_cache_binding.cpp", "ngram.cpp")
]
_cpp_flags = ["-O3", "-std=c++20"]
ngram_cache_cpp = load(
    name="ngram_cache_cpp",
    sources=_cpp_sources,
    extra_cflags=_cpp_flags,
)
21
+
22
+
23
class NgramCache:
    """Thin Python wrapper around the ``ngram_cache_cpp.Ngram`` extension object."""

    def __init__(
        self,
        branch_length: int = 18,
        min_match_window_size: int = 1,
        max_match_window_size: int = 10,
        min_bfs_breadth: int = 1,
        max_bfs_breadth: int = 8,
        draft_token_num: int = 8,
        match_type: str = "BFS",
        capacity: int = 1000000,
    ):
        """Build the C++ ``Param`` from the keyword arguments and create the cache.

        Every argument except ``capacity`` is copied onto the ``Param`` field of
        the same name; ``capacity`` is passed to the ``Ngram`` constructor.
        """
        param = ngram_cache_cpp.Param()
        param.branch_length = branch_length
        param.min_match_window_size = min_match_window_size
        param.max_match_window_size = max_match_window_size
        param.min_bfs_breadth = min_bfs_breadth
        param.max_bfs_breadth = max_bfs_breadth
        param.draft_token_num = draft_token_num
        param.match_type = match_type
        self.cache = ngram_cache_cpp.Ngram(capacity, param)

        # Not read anywhere in this class; kept because external code may use it.
        self.default_mask = np.ones((1, 1), dtype=np.int64)
        self.draft_token_num = draft_token_num

    def batch_put(self, batch_tokens: List[List[int]]):
        """Queue ``batch_tokens`` for insertion via ``Ngram.asyncInsert``."""
        self.cache.asyncInsert(batch_tokens)

    def synchronize(self):
        """Forward to ``Ngram.synchronize``."""
        self.cache.synchronize()

    def reset(self):
        """Forward to ``Ngram.reset``."""
        self.cache.reset()

    def batch_get(self, batch_tokens: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
        """Run ``Ngram.batchMatch`` and return ``(token, mask)`` as NumPy arrays."""
        result = self.cache.batchMatch(batch_tokens)
        return np.array(result.token), np.array(result.mask)

    def leaf_paths_from_mask(
        self, tokens: List[int], tree_mask: List[List[int]]
    ) -> List[List[int]]:
        """
        Find all leaf paths according to the binary tree_mask (i.e., paths that are not prefixes of any other path).

        A row is dropped when the set of columns it marks with 1 is a subset of
        the set of a later row that was kept.

        Args:
            tokens    : List[int]        # token list corresponding to columns
            tree_mask : List[List[int]]  # nxn binary matrix

        Returns:
            List[List[int]]  # token lists of only the leaf paths, preserving their order of appearance
        """
        # Columns marked with 1, per row.
        column_sets = [
            {col for col, flag in enumerate(row) if flag == 1} for row in tree_mask
        ]

        kept_sets = []
        leaf_rows = []
        # Walk from the last row backwards so longer paths are seen before the
        # shorter rows they contain.
        for row_idx in range(len(column_sets) - 1, -1, -1):
            columns = column_sets[row_idx]
            if any(columns <= kept for kept in kept_sets):
                continue
            kept_sets.append(columns)
            leaf_rows.append(row_idx)

        leaf_rows.reverse()
        # zip() pairs each token with its column flag and tolerates a row that
        # is shorter than `tokens` instead of raising IndexError.
        return [
            [token for token, flag in zip(tokens, tree_mask[row_idx]) if flag == 1]
            for row_idx in leaf_rows
        ]

    def debug_result(
        self, decoding_ids: np.ndarray, decoding_masks: np.ndarray, tokenizer=None
    ):
        """Log the draft tokens, masks and leaf paths of a ``batch_get`` result.

        Both arrays are reshaped using ``self.draft_token_num``. When a
        ``tokenizer`` is given, each leaf path is also logged decoded through
        ``tokenizer.decode``.
        """
        decoding_ids = decoding_ids.reshape(-1, self.draft_token_num)
        decoding_masks = decoding_masks.reshape(
            -1, self.draft_token_num, self.draft_token_num
        )
        logger.info(f"\n{decoding_ids=}\n{decoding_masks=}")
        for i in range(decoding_ids.shape[0]):
            leaf_paths = self.leaf_paths_from_mask(
                decoding_ids[i].tolist(), decoding_masks[i].tolist()
            )
            if tokenizer is None:
                logger.info(f"draft path {i}: {leaf_paths}")
            else:
                logger.info(f"result {i}:")
                for leaf_path in leaf_paths:
                    logger.info(
                        f"draft path {i}: {leaf_path} -> {tokenizer.decode(leaf_path, ensure_ascii=False)}"
                    )
116
+
117
+
118
+ # main function
119
if __name__ == "__main__":
    # Plain string (no placeholders, so no f-prefix) and a name that does not
    # shadow the builtin `format`.
    log_format = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(
        level=logging.DEBUG,
        format=log_format,
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )

    # Smoke test: insert two sequences, wait for the insert to finish, then
    # query three prefixes and log the resulting draft paths.
    token_ids = [
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
    ]
    cache = NgramCache(branch_length=12, draft_token_num=8)
    cache.batch_put(token_ids)

    cache.synchronize()
    decoding_ids, decoding_masks = cache.batch_get([[1, 2, 3], [3, 44], [3, 6, 999]])

    cache.debug_result(decoding_ids, decoding_masks)
@@ -0,0 +1,43 @@
1
+ #include <pybind11/pybind11.h>
2
+ #include <pybind11/stl.h>
3
+
4
+ #include "ngram.h"
5
+
6
+ PYBIND11_MODULE(ngram_cache_cpp, m) {
7
+ using namespace ngram;
8
+ namespace py = pybind11;
9
+ m.doc() = "";
10
+
11
+ py::class_<Ngram>(m, "Ngram")
12
+ .def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"))
13
+ .def("asyncInsert", &Ngram::asyncInsert, "")
14
+ .def("batchMatch", &Ngram::batchMatch, "")
15
+ .def("reset", &Ngram::reset, "")
16
+ .def("synchronize", &Ngram::synchronize, "");
17
+
18
+ py::class_<Param>(m, "Param")
19
+ .def(py::init<>())
20
+ .def_readwrite("enable", &Param::enable)
21
+ .def_readwrite("enable_router_mode", &Param::enable_router_mode)
22
+ .def_readwrite("min_bfs_breadth", &Param::min_bfs_breadth)
23
+ .def_readwrite("max_bfs_breadth", &Param::max_bfs_breadth)
24
+ .def_readwrite("min_match_window_size", &Param::min_match_window_size)
25
+ .def_readwrite("max_match_window_size", &Param::max_match_window_size)
26
+ .def_readwrite("branch_length", &Param::branch_length)
27
+ .def_readwrite("draft_token_num", &Param::draft_token_num)
28
+ .def_readwrite("match_type", &Param::match_type)
29
+ .def_readwrite("batch_min_match_window_size", &Param::batch_min_match_window_size)
30
+ .def_readwrite("batch_draft_token_num", &Param::batch_draft_token_num)
31
+ .def("get_draft_token_num", &Param::get_draft_token_num, "")
32
+ .def("get_min_match_window_size", &Param::get_min_match_window_size, "")
33
+ .def("parse", &Param::parse, "")
34
+ .def("resetBatchMinMatchWindowSize", &Param::resetBatchMinMatchWindowSize, "")
35
+ .def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
36
+ .def("detail", &Param::detail, "");
37
+
38
+ py::class_<Ngram::Result>(m, "Result")
39
+ .def(py::init<>())
40
+ .def_readwrite("token", &Ngram::Result::token)
41
+ .def_readwrite("mask", &Ngram::Result::mask)
42
+ .def("truncate", &Ngram::Result::truncate);
43
+ }
@@ -0,0 +1,125 @@
1
+ #pragma once
2
+
3
+ #include <cstddef>
4
+ #include <iostream>
5
+ #include <limits>
6
+ #include <regex>
7
+ #include <sstream>
8
+ #include <stdexcept>
9
+ #include <string>
10
+ #include <vector>
11
+
12
+ namespace ngram {
13
+
14
// Tunable parameters of the n-gram cache, plus optional per-batch-size
// overrides parsed from a compact "L-R|value,L-R|value" string.
//
// All scalar members have default member initializers so that a
// default-initialized Param never exposes indeterminate values (detail()
// reads every field).
struct Param {
  // Largest batch size that may appear as a range bound in parse().
  static constexpr size_t kMaxBatchSize = 128;

  bool enable = false;
  bool enable_router_mode = false;
  size_t min_bfs_breadth = 0;
  size_t max_bfs_breadth = 0;
  size_t min_match_window_size = 0;
  size_t max_match_window_size = 0;
  size_t branch_length = 0;
  size_t draft_token_num = 0;
  std::string match_type;

  // Per-batch-size overrides, indexed by batch size. An entry equal to
  // size_t's maximum value means "no override for this batch size".
  std::vector<size_t> batch_min_match_window_size;
  std::vector<size_t> batch_draft_token_num;

  // Returns the override for `batch_size` when one is set, otherwise
  // `draft_token_num - 1`.
  // NOTE(review): the fallback wraps around when draft_token_num == 0 --
  // confirm callers never configure 0.
  size_t get_draft_token_num(size_t batch_size) const {
    if (batch_size < batch_draft_token_num.size() && batch_draft_token_num[batch_size] != unsetValue()) {
      return batch_draft_token_num[batch_size];
    }
    return draft_token_num - 1;
  }

  // Returns the override for `batch_size` when one is set, otherwise
  // `min_match_window_size`.
  size_t get_min_match_window_size(size_t batch_size) const {
    if (batch_size < batch_min_match_window_size.size() && batch_min_match_window_size[batch_size] != unsetValue()) {
      return batch_min_match_window_size[batch_size];
    }
    return min_match_window_size;
  }

  // Parses a per-batch-size table such as "0-1|10,2-3|20".
  //
  // Each comma-separated segment is "L-R|value" and assigns `value` to every
  // index in [L, R]. The result is indexed by batch size; indices not covered
  // by any segment hold size_t's maximum value. An empty string yields an
  // empty table.
  //
  // Throws std::runtime_error when a segment is malformed, a number is not a
  // plain non-negative decimal integer, L > R, R > kMaxBatchSize, or an index
  // is assigned twice.
  std::vector<size_t> parse(const std::string& value) const {
    std::vector<size_t> result;
    if (value.empty()) {
      return result;
    }
    // Built once per call rather than once per segment.
    const std::regex comma_re(",");
    const std::regex pipe_re("\\|");
    const std::regex dash_re("-");

    std::vector<bool> mark;  // mark[i] is true once index i has been assigned
    for (const auto& seg : split(value, comma_re)) {
      const std::vector<std::string> part = split(seg, pipe_re);
      if (part.size() != 2) {
        throw std::runtime_error(
            "failed to get config, invalid config: " + seg + ", part's size = " + std::to_string(part.size()));
      }
      const std::vector<std::string> range = split(part[0], dash_re);
      if (range.size() != 2) {
        throw std::runtime_error("failed to get range, invalid config: " + value);
      }
      const size_t L = parseUnsigned(range[0], value);
      const size_t R = parseUnsigned(range[1], value);
      if (L > R || R > kMaxBatchSize) {
        throw std::runtime_error("invalid range, config: " + value);
      }
      if (R >= result.size()) {
        result.resize(R + 1, unsetValue());
        mark.resize(result.size(), false);
      }
      const size_t config = parseUnsigned(part[1], value);
      for (size_t pos = L; pos <= R; ++pos) {
        if (mark[pos]) {
          throw std::runtime_error("repeated position " + std::to_string(pos) + ", config : " + value);
        }
        mark[pos] = true;
        result[pos] = config;
      }
    }
    return result;
  }

  // Replaces the per-batch-size min-match-window overrides with parse(value).
  void resetBatchMinMatchWindowSize(const std::string& value) {
    batch_min_match_window_size = parse(value);
  }

  // Replaces the per-batch-size draft-token-count overrides with parse(value).
  void resetBatchReturnTokenNum(const std::string& value) {
    batch_draft_token_num = parse(value);
  }

  // Renders every field, including both override tables, as one line of text.
  std::string detail() const {
    std::stringstream ss;
    ss << "enable = " << enable << ", enable_router_mode = " << enable_router_mode
       << ", min_bfs_breadth = " << min_bfs_breadth << ", max_bfs_breadth = " << max_bfs_breadth
       << ", min_match_window_size = " << min_match_window_size << ", max_match_window_size = " << max_match_window_size
       << ", branch_length = " << branch_length << ", draft_token_num = " << draft_token_num
       << ", match_type = " << match_type;
    ss << ", batch_min_match_window_size(" << batch_min_match_window_size.size() << ") = ";
    for (size_t i = 0; i < batch_min_match_window_size.size(); ++i) {
      ss << i << "|" << batch_min_match_window_size[i] << ",";
    }
    ss << ", batch_draft_token_num(" << batch_draft_token_num.size() << ") = ";
    for (size_t i = 0; i < batch_draft_token_num.size(); ++i) {
      ss << i << "|" << batch_draft_token_num[i] << ",";
    }
    return ss.str();
  }

 private:
  // Sentinel stored in the override tables for "no override".
  static constexpr size_t unsetValue() {
    return std::numeric_limits<size_t>::max();
  }

  // Splits `text` on every match of `separator`.
  static std::vector<std::string> split(const std::string& text, const std::regex& separator) {
    std::sregex_token_iterator first{text.begin(), text.end(), separator, -1}, last;
    return std::vector<std::string>(first, last);
  }

  // Parses a non-negative decimal integer, ignoring surrounding spaces/tabs.
  // Unlike std::atoi this rejects empty, signed, non-numeric and overflowing
  // input instead of silently returning 0 or a truncated value.
  static size_t parseUnsigned(const std::string& text, const std::string& config) {
    const auto begin = text.find_first_not_of(" \t");
    if (begin == std::string::npos) {
      throw std::runtime_error("missing number, config: " + config);
    }
    const auto end = text.find_last_not_of(" \t");
    size_t parsed = 0;
    for (auto i = begin; i <= end; ++i) {
      const char c = text[i];
      if (c < '0' || c > '9') {
        throw std::runtime_error("invalid number '" + text + "', config: " + config);
      }
      const size_t digit = static_cast<size_t>(c - '0');
      if (parsed > (std::numeric_limits<size_t>::max() - digit) / 10) {
        throw std::runtime_error("number out of range '" + text + "', config: " + config);
      }
      parsed = parsed * 10 + digit;
    }
    return parsed;
  }
};
124
+
125
+ } // namespace ngram
@@ -0,0 +1,71 @@
1
+ #pragma once
2
+
3
+ #include <condition_variable>
4
+ #include <queue>
5
+
6
+ namespace utils {
7
+
8
// Mutex-protected FIFO queue with a blocking dequeue and a close() switch.
// Once closed, enqueue() and dequeue() both return false.
template <typename T>
class Queue {
 public:
  // Moves `rhs` onto the queue and wakes one waiter; false when closed.
  bool enqueue(T&& rhs) {
    std::unique_lock<std::mutex> guard(mutex_);
    if (closed_) {
      return false;
    }
    queue_.push(std::move(rhs));
    guard.unlock();  // notify without holding the lock
    cv_.notify_one();
    return true;
  }

  // Copies `rhs` onto the queue and wakes one waiter; false when closed.
  bool enqueue(const T& rhs) {
    std::unique_lock<std::mutex> guard(mutex_);
    if (closed_) {
      return false;
    }
    queue_.push(rhs);
    guard.unlock();  // notify without holding the lock
    cv_.notify_one();
    return true;
  }

  // Blocks until an item is available or the queue is closed. Returns false
  // when closed (even if items are still queued); otherwise moves the front
  // item into `rhs`, pops it and returns true.
  bool dequeue(T& rhs) {
    std::unique_lock<std::mutex> guard(mutex_);
    while (!closed_ && queue_.empty()) {
      cv_.wait(guard);
    }
    if (closed_) {
      return false;
    }
    rhs = std::move(queue_.front());
    queue_.pop();
    return true;
  }

  // Number of queued items at the time of the call.
  size_t size() const {
    std::unique_lock<std::mutex> guard(mutex_);
    const size_t count = queue_.size();
    return count;
  }

  // True when no items are queued at the time of the call.
  bool empty() const {
    std::unique_lock<std::mutex> guard(mutex_);
    const bool none = queue_.empty();
    return none;
  }

  // Marks the queue closed and wakes every waiter.
  void close() {
    std::unique_lock<std::mutex> guard(mutex_);
    closed_ = true;
    guard.unlock();
    cv_.notify_all();
  }

 private:
  std::queue<T> queue_;
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  bool closed_{false};
};
70
+
71
+ } // namespace utils
@@ -20,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
20
20
  ForwardBatch,
21
21
  ForwardMode,
22
22
  )
23
- from sglang.srt.speculative.eagle_utils import EagleDraftInput
23
+ from sglang.srt.speculative.eagle_info import EagleDraftInput
24
24
  from sglang.srt.utils import (
25
25
  require_attn_tp_gather,
26
26
  require_gathered_buffer,
@@ -302,6 +302,7 @@ class EAGLEDraftCudaGraphRunner:
302
302
  if bs != raw_bs:
303
303
  self.seq_lens.fill_(self.seq_len_fill_value)
304
304
  self.out_cache_loc.zero_()
305
+ self.positions.zero_()
305
306
 
306
307
  num_tokens = bs * self.num_tokens_per_bs
307
308
 
@@ -21,7 +21,8 @@ from sglang.srt.model_executor.forward_batch_info import (
21
21
  ForwardBatch,
22
22
  ForwardMode,
23
23
  )
24
- from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
24
+ from sglang.srt.speculative.eagle_info import EagleDraftInput
25
+ from sglang.srt.speculative.spec_utils import fast_topk
25
26
  from sglang.srt.utils import (
26
27
  require_attn_tp_gather,
27
28
  require_gathered_buffer,
@@ -331,6 +332,7 @@ class EAGLEDraftExtendCudaGraphRunner:
331
332
  if bs * self.num_tokens_per_bs != num_tokens:
332
333
  self.seq_lens.fill_(self.seq_len_fill_value)
333
334
  self.out_cache_loc.zero_()
335
+ self.positions.zero_()
334
336
  self.accept_length.fill_(1)
335
337
  self.extend_seq_lens.fill_(1)
336
338