sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377) hide show
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
1
+ """
2
+ Configuration argument parser for command-line applications.
3
+ Handles merging of YAML configuration files with command-line arguments.
4
+ """
5
+
6
+ import logging
7
+ from pathlib import Path
8
+ from typing import Any, Dict, List, Union
9
+
10
+ import yaml
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class ConfigArgumentMerger:
    """Handles merging of configuration file arguments with command-line arguments.

    A YAML file named on the command line with ``--config <path>`` (or
    ``--config=<path>``) is converted into ``--key value`` tokens. Those tokens
    are placed *in front of* the remaining CLI arguments, so for options where
    argparse keeps the last occurrence, explicitly passed CLI values win:
    CLI > Config > Defaults.
    """

    # Command-line flag that names the YAML configuration file.
    _CONFIG_FLAG = "--config"

    def __init__(self, boolean_actions: Union[List[str], None] = None):
        """Initialize with list of boolean action destinations.

        Args:
            boolean_actions: Config keys whose boolean values are always
                emitted as ``--key true`` / ``--key false``. Any other boolean
                key is emitted as a bare ``--key`` flag only when it is True.
        """
        self.boolean_actions = boolean_actions or []

    def merge_config_with_args(self, cli_args: List[str]) -> List[str]:
        """
        Merge configuration file arguments with command-line arguments.

        The config-derived arguments are prepended to the CLI arguments (with
        the ``--config`` flag and its path removed). CLI options are therefore
        parsed after the config values and take precedence:
        CLI > Config > Defaults.

        Args:
            cli_args: List of command-line arguments

        Returns:
            ``cli_args`` itself when no config flag is present; otherwise a new
            list with the config values merged in.

        Raises:
            ValueError: If the config flag is given more than once or without a
                path, the path is not a ``.yaml``/``.yml`` file, the file does
                not exist, or its root is not a mapping.
        """
        config_file_path = self._extract_config_file_path(cli_args)
        if not config_file_path:
            return cli_args

        config_args = self._parse_yaml_config(config_file_path)
        return self._insert_config_args(cli_args, config_args, config_file_path)

    def _find_config_flag_indices(self, args: List[str]) -> List[int]:
        """Return the index of every ``--config`` / ``--config=<path>`` token."""
        prefix = self._CONFIG_FLAG + "="
        return [
            i
            for i, arg in enumerate(args)
            if arg == self._CONFIG_FLAG or arg.startswith(prefix)
        ]

    def _extract_config_file_path(self, args: List[str]) -> Union[str, None]:
        """Extract the config file path from arguments.

        Returns:
            The path following ``--config`` (or embedded in ``--config=<path>``),
            or None when no config flag is present.

        Raises:
            ValueError: If the flag appears more than once or has no path.
        """
        config_indices = self._find_config_flag_indices(args)

        if len(config_indices) > 1:
            raise ValueError("Multiple config files specified! Only one allowed.")

        if not config_indices:
            return None

        config_index = config_indices[0]
        flag = args[config_index]
        if flag != self._CONFIG_FLAG:
            # ``--config=<path>`` form: the path lives in the same token.
            embedded_path = flag[len(self._CONFIG_FLAG) + 1 :]
            if not embedded_path:
                raise ValueError("No config file specified after --config flag!")
            return embedded_path

        if config_index == len(args) - 1:
            raise ValueError("No config file specified after --config flag!")

        return args[config_index + 1]

    def _insert_config_args(
        self, cli_args: List[str], config_args: List[str], config_file_path: str
    ) -> List[str]:
        """Return ``config_args`` followed by ``cli_args`` without the config flag.

        ``config_file_path`` is accepted for interface compatibility and is not
        used here.
        """
        config_index = self._find_config_flag_indices(cli_args)[0]
        # ``--config <path>`` spans two tokens; ``--config=<path>`` spans one.
        flag_span = 2 if cli_args[config_index] == self._CONFIG_FLAG else 1

        before_config = cli_args[:config_index]
        after_config = cli_args[config_index + flag_span :]

        # Config args first, so CLI args that follow override them.
        return config_args + before_config + after_config

    def _parse_yaml_config(self, file_path: str) -> List[str]:
        """
        Parse YAML configuration file and convert to argument list.

        Args:
            file_path: Path to the YAML configuration file

        Returns:
            List of arguments in format ['--key', 'value', ...]

        Raises:
            ValueError: If the file does not have a YAML suffix, does not exist,
                or its root is not a mapping.
            Exception: Any error raised while opening or parsing the file is
                logged and re-raised unchanged.
        """
        self._validate_yaml_file(file_path)

        try:
            with open(file_path, "r") as file:
                config_data = yaml.safe_load(file)
        except Exception as e:
            logger.error(f"Failed to read config file {file_path}: {e}")
            raise

        # Handle empty files or None content
        if config_data is None:
            config_data = {}

        if not isinstance(config_data, dict):
            raise ValueError("Config file must contain a dictionary at root level")

        return self._convert_config_to_args(config_data)

    def _validate_yaml_file(self, file_path: str) -> None:
        """Validate that the path has a YAML suffix and exists.

        Raises:
            ValueError: On a non-YAML suffix or a missing path.
        """
        path = Path(file_path)
        if path.suffix.lower() not in [".yaml", ".yml"]:
            raise ValueError(f"Config file must be YAML format, got: {path.suffix}")

        if not path.exists():
            raise ValueError(f"Config file not found: {file_path}")

    def _convert_config_to_args(self, config: Dict[str, Any]) -> List[str]:
        """Convert configuration dictionary to argument list.

        Null values are skipped. Booleans, lists and mappings are handled by
        dedicated rules; everything else becomes ``--key str(value)``.
        """
        import json  # local import: only needed for mapping-valued keys

        args = []

        for key, value in config.items():
            if value is None:
                # ``key:`` / ``key: ~`` means "not set"; emitting it would pass
                # the literal string "None" as the option value.
                continue
            if isinstance(value, bool):
                self._add_boolean_arg(args, key, value)
            elif isinstance(value, list):
                self._add_list_arg(args, key, value)
            elif isinstance(value, dict):
                # str(dict) yields a Python repr (single quotes), not JSON.
                self._add_scalar_arg(args, key, json.dumps(value))
            else:
                self._add_scalar_arg(args, key, value)

        return args

    def _add_boolean_arg(self, args: List[str], key: str, value: bool) -> None:
        """Add boolean argument to the list."""
        if key in self.boolean_actions:
            # For boolean actions, always add the flag and value
            args.extend([f"--{key}", str(value).lower()])
        else:
            # For regular booleans, only add flag if True; False is omitted.
            if value:
                args.append(f"--{key}")

    def _add_list_arg(self, args: List[str], key: str, value: List[Any]) -> None:
        """Add list argument as ``--key item1 item2 ...``; empty lists are omitted."""
        if value:  # Only add if list is not empty
            args.append(f"--{key}")
            args.extend(str(item) for item in value)

    def _add_scalar_arg(self, args: List[str], key: str, value: Any) -> None:
        """Add scalar argument as ``--key str(value)``."""
        args.extend([f"--{key}", str(value)])
@@ -0,0 +1,151 @@
1
+ from dataclasses import dataclass
2
+ from typing import TYPE_CHECKING, Any, Callable, Optional
3
+
4
+ import torch
5
+
6
+ from sglang.srt.layers.moe import get_moe_runner_backend
7
+ from sglang.srt.layers.moe.utils import is_sbo_enabled
8
+ from sglang.srt.layers.quantization import deep_gemm_wrapper
9
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
10
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
11
+ from sglang.srt.utils import get_int_env_var
12
+
13
+ if TYPE_CHECKING:
14
+ from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
15
+
16
+
17
class SboFlags:
    """Predicates describing which single-batch-overlap (SBO) modes are active.

    Every predicate re-evaluates the global switches on each call.
    """

    # TODO may have: "enable_dispatch_shared_one_stream_overlap", "enable_dispatch_gateup_gemm_two_stream_overlap", ...

    @classmethod
    def enable_combine_down_gemm_two_stream_overlap(cls):
        """True when SBO is on AND the MoE runner backend is flashinfer cutedsl."""
        sbo_on = is_sbo_enabled()
        if not sbo_on:
            return sbo_on
        # currently only cutedsl backend supports it
        return get_moe_runner_backend().is_flashinfer_cutedsl()

    @classmethod
    def enable_combine_shared_two_stream_overlap(cls):
        """Combine/shared-expert overlap follows the global SBO switch."""
        return is_sbo_enabled()

    @classmethod
    def fuse_shared_experts_inside_sbo(cls):
        """Whether shared experts are evaluated inside the SBO code path."""
        # TODO after antgroup's PR, should be `... or cls.enable_dispatch_shared_one_stream_overlap()`
        return cls.enable_combine_shared_two_stream_overlap()
36
+
37
+
38
@dataclass
class CombineOverlapArgs:
    # Settings handed to ``experts.combine(..., overlap_args=...)`` in execute_sbo.
    # NOTE: `overlap` refers to overlapping with the down GEMM only,
    # not to the general two-stream overlap.
    overlap: bool
    stream: torch.cuda.Stream  # side stream (``alt_stream``) for the combine
    wait_event: torch.cuda.Event  # event the combine is given to wait on
    num_sms: int  # SM budget for the combine's communication
    signal: Optional[torch.Tensor] = None  # per-local-expert signal; set only with down-GEMM overlap
    threshold: int = -1  # set to the compute SM count with down-GEMM overlap; -1 otherwise
47
+
48
+
49
@dataclass
class DownGemmOverlapArgs:
    # Passed to ``experts.moe_impl`` only when combine/down-GEMM overlap is enabled.
    num_sms: int  # SM budget left for compute (``compute_num_sms``)
    signal: torch.Tensor  # uint32 per-local-expert signal, shared with the combine args
    start_event: torch.cuda.Event  # same event the combine receives as ``wait_event``
54
+
55
+
56
def execute_sbo(
    forward_shared_experts: Callable[[], Any],
    experts: "DeepEPMoE",
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    forward_batch: ForwardBatch,
    alt_stream: Optional = None,
):
    """Run one MoE layer as dispatch -> expert compute -> shared experts -> combine.

    The overlap arguments come from ``_compute_overlap_args``. They are passed to
    ``experts.moe_impl`` and ``experts.combine`` and carry ``alt_stream`` plus
    the SM budgets.

    Args:
        forward_shared_experts: Zero-argument callable computing the shared
            experts; only invoked when combine/shared overlap is enabled.
        experts: The MoE layer providing dispatch / moe_impl / combine.
        hidden_states, topk_idx, topk_weights: Routing inputs for dispatch.
        forward_batch: Current forward batch.
        alt_stream: Side stream used for the overlapped combine.

    Returns:
        ``(combined_hidden_states, shared_output)``. ``shared_output`` is None
        unless combine/shared two-stream overlap is enabled.
    """
    dispatch_output = experts.dispatch(
        hidden_states, topk_idx, topk_weights, forward_batch
    )

    overlap_args = _compute_overlap_args(dispatch_output, alt_stream)
    combine_args, down_gemm_args, meta_args = overlap_args

    expert_output = experts.moe_impl(
        dispatch_output, down_gemm_overlap_args=down_gemm_args
    )

    after_down_event = meta_args.get("record_event_after_down")
    if after_down_event is not None:
        # This is the same event the combine receives as its ``wait_event``.
        after_down_event.record()

    shared_output = None
    if SboFlags.enable_combine_shared_two_stream_overlap():
        # TODO reduce sm for non-deepgemm
        sm_budget = meta_args["compute_num_sms"]
        with deep_gemm_wrapper.configure_deep_gemm_num_sms(sm_budget):
            shared_output = forward_shared_experts()

    combined = experts.combine(
        expert_output,
        dispatch_output.topk_idx,
        dispatch_output.topk_weights,
        forward_batch,
        overlap_args=combine_args,
    )
    return combined, shared_output
97
+
98
+
99
def _compute_overlap_args(dispatch_output, alt_stream):
    """Build the per-call overlap arguments used by ``execute_sbo``.

    Args:
        dispatch_output: Output of ``experts.dispatch``. Its ``hidden_states_fp8``
            (or the first element, when it is a tuple) is unpacked as a 3-D
            tensor whose leading dimension is the number of local experts.
        alt_stream: Stored as ``CombineOverlapArgs.stream``. Must not be None
            when any SBO overlap mode is enabled.

    Returns:
        ``(combine_overlap_args, down_gemm_overlap_args, meta_overlap_args)``.
        With all overlap disabled this is ``(None, None, {})``. Otherwise
        ``meta_overlap_args["compute_num_sms"]`` is always set. With
        combine/down-GEMM overlap, ``down_gemm_overlap_args`` is built and
        shares the signal tensor and event with the combine args. Without it,
        ``meta_overlap_args["record_event_after_down"]`` holds the event the
        caller must record.

    Raises:
        ValueError: If ``SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS`` leaves no SMs
            for communication or for compute on the current device.
    """
    if not (
        SboFlags.enable_combine_down_gemm_two_stream_overlap()
        or SboFlags.enable_combine_shared_two_stream_overlap()
    ):
        return None, None, {}

    # Check the precondition before querying the device.
    assert alt_stream is not None

    hidden_states = dispatch_output.hidden_states_fp8
    if isinstance(hidden_states, tuple):
        hidden_states = hidden_states[0]

    # Only the expert count is used; the unpacking also enforces a 3-D shape.
    num_local_experts, _num_tokens_static, _hidden_dim = hidden_states.shape

    total_num_sms = torch.cuda.get_device_properties(
        device="cuda"
    ).multi_processor_count
    communicate_num_sms = get_int_env_var("SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS", 32)
    compute_num_sms = total_num_sms - communicate_num_sms
    # A non-positive SM budget on either side cannot be honoured; fail loudly
    # instead of handing it to the kernels.
    if communicate_num_sms <= 0 or compute_num_sms <= 0:
        raise ValueError(
            "SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS must be between 1 and "
            f"{total_num_sms - 1} for a device with {total_num_sms} SMs, "
            f"got {communicate_num_sms}"
        )

    combine_wait_event = torch.cuda.Event()
    combine_overlap_args = CombineOverlapArgs(
        overlap=False,
        num_sms=communicate_num_sms,
        stream=alt_stream,
        wait_event=combine_wait_event,
    )
    meta_overlap_args = dict(
        compute_num_sms=compute_num_sms,
    )
    down_gemm_overlap_args = None

    if SboFlags.enable_combine_down_gemm_two_stream_overlap():
        # TODO use zero_allocator to remove this `torch.zeros` call
        # NOTE ours v2 use uint32 not int32 currently
        combine_signal = torch.zeros(
            num_local_experts, dtype=torch.uint32, device=hidden_states.device
        )

        down_gemm_overlap_args = DownGemmOverlapArgs(
            signal=combine_signal,
            start_event=combine_wait_event,
            num_sms=compute_num_sms,
        )
        combine_overlap_args.overlap = True
        combine_overlap_args.signal = combine_signal
        combine_overlap_args.threshold = compute_num_sms
    else:
        # Without down-GEMM overlap, the caller records the event after moe_impl.
        meta_overlap_args |= dict(
            record_event_after_down=combine_wait_event,
        )

    return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args
@@ -0,0 +1,374 @@
1
+ #include "ngram.h"
2
+
3
+ #include <algorithm>
4
+ #include <cstring>
5
+ #include <limits>
6
+ #include <queue>
7
+ #include <vector>
8
+
9
+ namespace ngram {
10
+
11
// Scratch node of the draft tree assembled while matching: maps a child token id to the
// index of that child inside the tree vector.
struct Node {
  std::unordered_map<int32_t, int32_t> next{};
};
14
+
15
+ Ngram::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
16
+ Ngram::Result info;
17
+ std::vector<int32_t> prevs;
18
+ info.token.reserve(draft_token_num);
19
+ prevs.reserve(draft_token_num);
20
+ std::queue<std::tuple<int32_t, int32_t, int32_t>> queue;
21
+ info.token.emplace_back(last_token);
22
+ prevs.emplace_back(-1);
23
+
24
+ for (auto [token, next] : tree[root].next) {
25
+ queue.emplace(token, next, 0);
26
+ }
27
+ while (queue.size()) {
28
+ auto [token, next, prev] = queue.front();
29
+ queue.pop();
30
+ info.token.emplace_back(token);
31
+ prevs.emplace_back(prev);
32
+ for (auto [t, n] : tree[next].next) {
33
+ queue.emplace(t, n, info.token.size() - 1);
34
+ }
35
+ }
36
+
37
+ // zero padding to length
38
+ while (info.token.size() < draft_token_num) {
39
+ info.token.emplace_back(0);
40
+ prevs.emplace_back(0);
41
+ }
42
+
43
+ int n = info.token.size();
44
+ info.mask.resize(n * n, 0);
45
+ info.mask[0] = 1;
46
+ for (int i = 0; i < n; ++i) {
47
+ if (prevs[i] != -1) {
48
+ memcpy(&info.mask[i * n], &info.mask[prevs[i] * n], prevs[i] + 1);
49
+ }
50
+ info.mask[i * n + i] = 1;
51
+ }
52
+
53
+ return info;
54
+ }
55
+
56
+ Ngram::Ngram(size_t capacity, const Param& param) {
57
+ param_ = param;
58
+ nodes_.resize(capacity);
59
+ for (auto& node : nodes_) {
60
+ node_pool_.emplace_back(&node);
61
+ }
62
+ free_node_count_ = node_pool_.size();
63
+ root_ = getNode();
64
+
65
+ if (!(param_.branch_length > 1)) {
66
+ throw std::runtime_error(
67
+ "param_.branch_length must be greater than 1, current value: " + std::to_string(param_.branch_length));
68
+ }
69
+ if (!(param_.min_match_window_size > 0)) {
70
+ throw std::runtime_error(
71
+ "min_match_window_size must be greater than 0, current value: " + std::to_string(param_.min_match_window_size));
72
+ }
73
+ if (!(param_.min_match_window_size <= param_.max_match_window_size)) {
74
+ throw std::runtime_error(
75
+ "min_match_window_size must be less than or equal to max_match_window_size, current min_match_window_size: " +
76
+ std::to_string(param_.min_match_window_size) +
77
+ ", max_match_window_size: " + std::to_string(param_.max_match_window_size));
78
+ }
79
+ if (!(param_.max_match_window_size < param_.branch_length)) {
80
+ throw std::runtime_error(
81
+ "max_match_window_size must be less than branch_length, current max_match_window_size: " +
82
+ std::to_string(param_.max_match_window_size) + ", branch_length: " + std::to_string(param_.branch_length));
83
+ }
84
+ if (!(param_.min_bfs_breadth > 0)) {
85
+ throw std::runtime_error(
86
+ "min_bfs_breadth must be greater than 0, current value: " + std::to_string(param_.min_bfs_breadth));
87
+ }
88
+ if (!(param_.min_bfs_breadth <= param_.max_bfs_breadth)) {
89
+ throw std::runtime_error(
90
+ "min_bfs_breadth must be less than or equal to max_bfs_breadth, current min_bfs_breadth: " +
91
+ std::to_string(param_.min_bfs_breadth) + ", max_bfs_breadth: " + std::to_string(param_.max_bfs_breadth));
92
+ }
93
+ if (!(param_.draft_token_num > 0)) {
94
+ throw std::runtime_error(
95
+ "draft_token_num must be greater than 0, current value: " + std::to_string(param_.draft_token_num));
96
+ }
97
+ for (auto config : param_.batch_draft_token_num) {
98
+ if (config != std::numeric_limits<decltype(config)>::max()) {
99
+ if (!(config <= param_.draft_token_num)) {
100
+ throw std::runtime_error(
101
+ "batch_draft_token_num config value " + std::to_string(config) +
102
+ " must be less than or equal to draft_token_num: " + std::to_string(param_.draft_token_num));
103
+ }
104
+ }
105
+ }
106
+ for (auto config : param_.batch_min_match_window_size) {
107
+ if (config != std::numeric_limits<decltype(config)>::max()) {
108
+ if (!(config >= param_.min_match_window_size)) {
109
+ throw std::runtime_error(
110
+ "batch_min_match_window_size config value " + std::to_string(config) +
111
+ " must be greater than or equal to min_match_window_size: " + std::to_string(param_.min_match_window_size));
112
+ }
113
+ if (!(config <= param_.max_match_window_size)) {
114
+ throw std::runtime_error(
115
+ "batch_min_match_window_size config value " + std::to_string(config) +
116
+ " must be less than or equal to max_match_window_size: " + std::to_string(param_.max_match_window_size));
117
+ }
118
+ }
119
+ }
120
+
121
+ quit_flag_ = false;
122
+ insert_worker_ = std::thread(&Ngram::insert, this);
123
+ }
124
+
125
+ Ngram::~Ngram() {
126
+ quit_flag_ = true;
127
+ insert_queue_.close();
128
+ insert_worker_.join();
129
+ }
130
+
131
+ std::vector<std::pair<TrieNode*, int32_t>> Ngram::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
132
+ auto draft_token_num = param_.get_draft_token_num(batch_size);
133
+ auto min_match_window_size = param_.get_min_match_window_size(batch_size);
134
+ auto max_match_window_size = param_.max_match_window_size;
135
+ std::vector<std::pair<TrieNode*, int32_t>> result;
136
+ result.reserve(param_.max_match_window_size - param_.min_match_window_size);
137
+ for (int32_t match_window_size = std::min(tokens.size(), param_.max_match_window_size);
138
+ match_window_size >= param_.min_match_window_size;
139
+ --match_window_size) {
140
+ auto start = tokens.data() + tokens.size() - match_window_size;
141
+ auto end = start + match_window_size;
142
+ auto cursor = root_;
143
+ while (start != end) {
144
+ auto iter = cursor->child.find(*start);
145
+ if (iter == cursor->child.end()) {
146
+ cursor = nullptr;
147
+ break;
148
+ }
149
+ ++start;
150
+ cursor = iter->second;
151
+ }
152
+ if (cursor) {
153
+ result.emplace_back(std::make_pair(cursor, match_window_size));
154
+ }
155
+ }
156
+ return result;
157
+ }
158
+
159
+ void Ngram::squeeze(size_t count) {
160
+ if (!(node_pool_.size() >= free_node_count_ + count)) {
161
+ throw std::runtime_error(
162
+ "Insufficient node size to release required nodes. "
163
+ "available to release: " +
164
+ std::to_string(node_pool_.size() - free_node_count_) + ", required to release: " + std::to_string(count));
165
+ }
166
+ while (count--) {
167
+ auto last = global_lru_.back();
168
+ global_lru_.pop_back();
169
+
170
+ if (!last->child.empty()) {
171
+ throw std::runtime_error("The node to be released still has child nodes and cannot be released. ");
172
+ }
173
+
174
+ last->parent->lru.erase(last->parent_lru_pos);
175
+ last->parent->sorted_children.erase(last);
176
+ last->parent->child.erase(last->token);
177
+
178
+ node_pool_[free_node_count_++] = last;
179
+ }
180
+ }
181
+
182
+ void Ngram::synchronize() const {
183
+ while (!insert_queue_.empty()) {
184
+ std::this_thread::sleep_for(std::chrono::microseconds(10));
185
+ }
186
+ }
187
+
188
+ void Ngram::insert() {
189
+ while (!quit_flag_) {
190
+ std::vector<int32_t> data;
191
+ if (!insert_queue_.dequeue(data)) {
192
+ continue;
193
+ }
194
+ const auto* token = data.data();
195
+ size_t size = data.size();
196
+ std::unique_lock<std::mutex> lock(mutex_);
197
+
198
+ for (size_t i = 0; i + param_.min_match_window_size < size; ++i) {
199
+ auto start = token + i;
200
+ auto end = start + std::min(size - i, param_.branch_length);
201
+
202
+ if (end - start > free_node_count_) {
203
+ squeeze(end - start - free_node_count_);
204
+ }
205
+
206
+ TrieNode* cursor = root_;
207
+ path_.clear();
208
+ while (start != end) {
209
+ auto token = *start;
210
+ auto iter = cursor->child.find(token);
211
+ if (iter == cursor->child.end()) {
212
+ iter = cursor->child.insert({token, getNode()}).first;
213
+ auto node = iter->second;
214
+
215
+ cursor->lru.emplace_front(node);
216
+ global_lru_.emplace_back(node);
217
+
218
+ node->token = token;
219
+ node->parent = cursor;
220
+ node->parent_lru_pos = cursor->lru.begin();
221
+ node->global_lru_pos = --global_lru_.end();
222
+ node->freq = 1;
223
+ cursor->sorted_children.insert(node);
224
+ } else {
225
+ auto node = iter->second;
226
+ cursor->sorted_children.erase(node);
227
+ node->freq++;
228
+ cursor->sorted_children.insert(node);
229
+ cursor->lru.splice(cursor->lru.begin(), cursor->lru, node->parent_lru_pos);
230
+ }
231
+ cursor = iter->second;
232
+ path_.emplace_back(cursor);
233
+ ++start;
234
+ }
235
+
236
+ for (auto it = path_.rbegin(); it != path_.rend(); ++it) {
237
+ TrieNode* node = *it;
238
+ global_lru_.splice(global_lru_.begin(), global_lru_, node->global_lru_pos);
239
+ }
240
+ }
241
+ }
242
+ }
243
+
244
+ void Ngram::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
245
+ for (auto&& token : tokens) {
246
+ insert_queue_.enqueue(std::move(token));
247
+ }
248
+ }
249
+
250
+ Ngram::Result Ngram::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
251
+ std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
252
+
253
+ double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /
254
+ (param_.max_match_window_size - param_.min_match_window_size + 1);
255
+
256
+ auto draft_token_num = param_.get_draft_token_num(batch_size);
257
+ std::vector<Node> tree(draft_token_num + 1);
258
+ int root = 0;
259
+ int cursor = 1;
260
+
261
+ for (auto [node, depth] : nodes) {
262
+ std::queue<std::tuple<int32_t, double, const TrieNode*>> queue; // parent, bfs_breadth, node
263
+ queue.push({root, (param_.max_match_window_size - depth) * bfs_breadth_scale + param_.min_bfs_breadth, node});
264
+ while (queue.size() && cursor <= draft_token_num) {
265
+ auto front = queue.front();
266
+ queue.pop();
267
+
268
+ auto parent = std::get<0>(front);
269
+ auto cur_breadth = std::get<1>(front);
270
+ auto iter = std::get<2>(front)->lru.begin();
271
+
272
+ auto breadth = std::max(1, int32_t(cur_breadth));
273
+ for (int i = 0; i < breadth && iter != std::get<2>(front)->lru.end() && cursor <= draft_token_num; ++i, ++iter) {
274
+ auto token = (*iter)->token;
275
+ auto pos = -1;
276
+ if (auto tit = tree[parent].next.find(token); tit != tree[parent].next.end()) {
277
+ pos = tit->second;
278
+ } else {
279
+ pos = tree[parent].next.insert(std::make_pair(token, cursor++)).first->second;
280
+ }
281
+ queue.emplace(pos, cur_breadth - bfs_breadth_scale, *iter);
282
+ }
283
+ }
284
+ }
285
+
286
+ return fillResult(tokens.back(), draft_token_num + 1, tree, root);
287
+ }
288
+
289
+ Ngram::Result Ngram::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
290
+ std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
291
+ auto draft_token_num = param_.get_draft_token_num(batch_size);
292
+
293
+ struct CompareByLastDouble {
294
+ bool operator()(
295
+ const std::tuple<double, const TrieNode*, double>& a, // parent_pos, node, final_prob
296
+ const std::tuple<double, const TrieNode*, double>& b) const {
297
+ return std::get<2>(a) < std::get<2>(b);
298
+ }
299
+ };
300
+
301
+ std::priority_queue<
302
+ std::tuple<double, const TrieNode*, double>,
303
+ std::vector<std::tuple<double, const TrieNode*, double>>,
304
+ CompareByLastDouble>
305
+ heap;
306
+
307
+ std::vector<Node> tree(draft_token_num + 1);
308
+
309
+ int root = 0;
310
+ int cursor = 1;
311
+ int top_k = param_.max_bfs_breadth;
312
+
313
+ auto addToHeap = [&heap, &top_k](int parent, const TrieNode* trie_node, double prob) -> void {
314
+ double sum_freq = 0.0;
315
+ int count = 0;
316
+ std::list<std::pair<TrieNode*, int32_t>> topk_children;
317
+ for (auto* child : trie_node->sorted_children) {
318
+ sum_freq += static_cast<double>(child->freq);
319
+ topk_children.emplace_back(child, child->freq);
320
+ if (++count >= top_k) break;
321
+ }
322
+ if (sum_freq <= 0) sum_freq = 1.0;
323
+ for (const auto& [child, freq] : topk_children) {
324
+ double norm_freq = static_cast<double>(freq) / sum_freq * prob;
325
+ heap.emplace(parent, child, norm_freq);
326
+ }
327
+ };
328
+
329
+ for (auto [node, _] : nodes) {
330
+ addToHeap(root, node, 1.0);
331
+
332
+ while (!heap.empty() && cursor <= draft_token_num) {
333
+ auto [parent, trie_node, prob] = heap.top(); // parent_pos, node, final_prob
334
+ heap.pop();
335
+ auto token = trie_node->token;
336
+ int pos = -1;
337
+ auto tit = tree[parent].next.find(token);
338
+ if (tit != tree[parent].next.end()) {
339
+ pos = tit->second;
340
+ } else {
341
+ pos = cursor++;
342
+ tree[parent].next[token] = pos;
343
+ }
344
+ addToHeap(pos, trie_node, prob);
345
+ }
346
+ }
347
+
348
+ return fillResult(tokens.back(), draft_token_num + 1, tree, root);
349
+ }
350
+
351
+ Ngram::Result Ngram::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
352
+ std::unique_lock<std::mutex> lock(mutex_);
353
+ Result merged_result;
354
+ auto match_func = param_.match_type == "BFS" ? &Ngram::matchBFS : &Ngram::matchProb;
355
+ for (const auto& tks : tokens) {
356
+ Result res = (this->*match_func)(tks, tokens.size());
357
+ merged_result.token.insert(merged_result.token.end(), res.token.begin(), res.token.end());
358
+ merged_result.mask.insert(merged_result.mask.end(), res.mask.begin(), res.mask.end());
359
+ }
360
+ return merged_result;
361
+ }
362
+
363
+ void Ngram::Result::truncate(size_t n) {
364
+ if (n < token.size()) {
365
+ int full_n = token.size();
366
+ for (int i = 1; i < n; ++i) {
367
+ memcpy(&mask[i * n], &mask[i * full_n], sizeof(mask[0]) * n);
368
+ }
369
+ token.resize(n);
370
+ mask.resize(n * n);
371
+ }
372
+ }
373
+
374
+ } // namespace ngram