sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
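The largest new file in this release is sglang/srt/mem_cache/mamba_radix_cache.py (+993 lines), whose diff follows. Its LRUList class tracks eviction order with a doubly linked list anchored by dummy head and tail sentinel nodes rather than by scanning timestamps. The short standalone sketch below only illustrates that sentinel-based LRU pattern; the Node and LRU names and the touch/evict_lru methods are hypothetical and are not part of the sglang code shown in the diff.

# Illustrative sketch (hypothetical names, not sglang code): an LRU list built
# from a doubly linked list with dummy head/tail sentinels, the same pattern
# LRUList below uses via prev/next (or mamba_prev/mamba_next) attributes.
class Node:
    def __init__(self, key):
        self.key = key
        self.prev = None
        self.next = None

class LRU:
    def __init__(self):
        self.head = Node(None)  # most recently used side
        self.tail = Node(None)  # least recently used side
        self.head.next = self.tail
        self.tail.prev = self.head
        self.nodes = {}

    def _unlink(self, node):
        node.prev.next = node.next
        node.next.prev = node.prev

    def _link_after_head(self, node):
        node.prev = self.head
        node.next = self.head.next
        self.head.next.prev = node
        self.head.next = node

    def touch(self, key):
        # Move an entry to the MRU position, inserting it if it is new.
        node = self.nodes.get(key)
        if node is None:
            node = Node(key)
            self.nodes[key] = node
        else:
            self._unlink(node)
        self._link_after_head(node)

    def evict_lru(self):
        # Remove and return the least recently used key, or None if empty.
        node = self.tail.prev
        if node is self.head:
            return None
        self._unlink(node)
        del self.nodes[node.key]
        return node.key

lru = LRU()
for k in ("a", "b", "c"):
    lru.touch(k)
lru.touch("a")                 # "a" becomes most recently used
assert lru.evict_lru() == "b"  # "b" is now the least recently used

The sentinels avoid special cases when inserting at the MRU end or evicting from the LRU end, which appears to be the same reason LRUList keeps dummy head and tail TreeNodes in the code below.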
sglang/srt/mem_cache/mamba_radix_cache.py
@@ -0,0 +1,993 @@
1
+ from __future__ import annotations
2
+
3
+ """
4
+ Copyright 2023-2024 SGLang Team
5
+ Licensed under the Apache License, Version 2.0 (the "License");
6
+ you may not use this file except in compliance with the License.
7
+ You may obtain a copy of the License at
8
+
9
+ http://www.apache.org/licenses/LICENSE-2.0
10
+
11
+ Unless required by applicable law or agreed to in writing, software
12
+ distributed under the License is distributed on an "AS IS" BASIS,
13
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ See the License for the specific language governing permissions and
15
+ limitations under the License.
16
+ """
17
+
18
+ """
19
+ The radix tree data structure for managing the hybrid (full and Mamba) KV cache.
20
+ """
21
+
22
+ import heapq
23
+ import time
24
+ from collections import defaultdict
25
+ from typing import TYPE_CHECKING, List, Optional, Tuple
26
+
27
+ import torch
28
+
29
+ from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
30
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
31
+ from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
32
+ from sglang.srt.mem_cache.radix_cache import (
33
+ RadixKey,
34
+ _key_match_page_size1,
35
+ get_child_key,
36
+ )
37
+
38
+ if TYPE_CHECKING:
39
+ from sglang.srt.managers.schedule_batch import Req
40
+
41
+ import logging
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
+ class TreeNode:
47
+
48
+ counter = 0
49
+
50
+ def __init__(self, id: Optional[int] = None):
51
+ self.children = defaultdict(TreeNode)
52
+ self.parent: TreeNode = None
53
+ self.key: RadixKey = None
54
+ self.value: Optional[torch.Tensor] = None
55
+ self.mamba_value: Optional[torch.Tensor] = None
56
+ # invariant: for any node, if mamba_lock_ref is locked, full_lock_ref must be locked;
57
+ # if full_lock_ref is locked, mamba_lock_ref doesn't need to be locked. So,
58
+ # full_lock_ref is always >= mamba_lock_ref.
59
+ # for full_lock, once it is locked, its parent must be locked as well
60
+ # for mamba_lock, it only need lock node itself
61
+ self.full_lock_ref = 0
62
+ self.mamba_lock_ref = 0
63
+ # last access time is only used for sanity check. LRU is maintained by the lru list.
64
+ self.last_access_time = time.monotonic()
65
+
66
+ self.hit_count = 0
67
+ # store the host indices of KV cache
68
+ self.host_value = None
69
+
70
+ # for lru list, invariant:
71
+ # 1. prev has greater last_access_time
72
+ # 2. next has smaller last_access_time
73
+ self.prev = None
74
+ self.next = None
75
+ self.mamba_prev = None
76
+ self.mamba_next = None
77
+
78
+ self.id = TreeNode.counter if id is None else id
79
+ TreeNode.counter += 1
80
+
81
+ @property
82
+ def evicted(self):
83
+ return self.value is None
84
+
85
+ @property
86
+ def backuped(self):
87
+ return self.host_value is not None
88
+
89
+ def __lt__(self, other: "TreeNode"):
90
+ return self.last_access_time < other.last_access_time
91
+
92
+
93
+ class LRUList:
94
+ def __init__(self, mamba: bool = False):
95
+ self.mamba = mamba
96
+ if self.mamba:
97
+ self.prv = "mamba_prev"
98
+ self.nxt = "mamba_next"
99
+ self.lock_ref = "mamba_lock_ref"
100
+ else:
101
+ self.prv = "prev"
102
+ self.nxt = "next"
103
+ self.lock_ref = "full_lock_ref"
104
+ # Initialize dummy head and tail nodes
105
+ self.head = TreeNode() # Most recently used side
106
+ self.tail = TreeNode() # Least recently used side
107
+ setattr(self.head, self.nxt, self.tail) # self.head.next = self.tail
108
+ setattr(self.tail, self.prv, self.head) # self.tail.prev = self.head
109
+ self.cache = {}
110
+
111
+ def _add_node(self, node):
112
+ """Helper to add node right after head (most recently used)"""
113
+ self._add_node_after(self.head, node)
114
+
115
+ def _add_node_after(self, old_node, new_node):
116
+ """Helper to add node right after old_node"""
117
+ setattr(new_node, self.prv, old_node) # new_node.prev = old_node
118
+ setattr(
119
+ new_node, self.nxt, getattr(old_node, self.nxt)
120
+ ) # new_node.next = old_node.next
121
+ setattr(
122
+ getattr(old_node, self.nxt), self.prv, new_node
123
+ ) # old_node.next.prev = new_node
124
+ setattr(old_node, self.nxt, new_node) # old_node.next = new_node
125
+
126
+ def _remove_node(self, node):
127
+ """Helper to remove node from linked list"""
128
+ setattr(
129
+ getattr(node, self.prv), self.nxt, getattr(node, self.nxt)
130
+ ) # node.prev.next = node.next
131
+ setattr(
132
+ getattr(node, self.nxt), self.prv, getattr(node, self.prv)
133
+ ) # node.next.prev = node.prev
134
+
135
+ def _get_lru(self) -> Optional[TreeNode]:
136
+ """
137
+ Get the least recently used node
138
+ """
139
+ if len(self.cache) == 0:
140
+ return None
141
+ return getattr(self.tail, self.prv)
142
+
143
+ def reset_node_mru(self, node):
144
+ """
145
+ Move a (existing) node to most recently used position
146
+ """
147
+ assert node.id in self.cache, f"Resetting node {node.id=} not in lru list"
148
+ assert (
149
+ not self.mamba or node.mamba_value is not None
150
+ ), f"Resetting mamba tombstone node in mamba lru list: {node.id=}"
151
+ self._remove_node(node)
152
+ self._add_node(node)
153
+
154
+ def reset_node_and_parents_mru(self, node, root_node):
155
+ """
156
+ Move an (existing) node and its parents to most recently used position. Child node is
157
+ more recently used than parent node.
158
+ """
159
+ prev_node = self.head
160
+ while node != root_node:
161
+ if not self.mamba or node.mamba_value is not None:
162
+ assert (
163
+ node.id in self.cache
164
+ ), f"Resetting node {node.id=} not in lru list when resetting node and parents mru"
165
+ self._remove_node(node)
166
+ self._add_node_after(prev_node, node)
167
+ prev_node = node
168
+ node = node.parent
169
+
170
+ def insert_mru(self, node):
171
+ """
172
+ Insert a (new) node as most recently used
173
+ """
174
+ assert (
175
+ not self.mamba or node.mamba_value is not None
176
+ ), f"Inserting mamba tombstone node in mamba lru list: {node.id=}"
177
+ assert (
178
+ node.id not in self.cache
179
+ ), f"Inserting node {node.id=} already in lru list, existing node: {self.cache[node.id].id=}"
180
+ self.cache[node.id] = node
181
+ self._add_node(node)
182
+
183
+ def remove_node(self, node: TreeNode):
184
+ """
185
+ Remove node from lru list
186
+ """
187
+ assert node.id in self.cache, f"Removing node {node.id=} not in lru list"
188
+ assert (
189
+ not self.mamba or node.mamba_value is not None
190
+ ), f"Removing mamba tombstone node from mamba lru list: {node.id=}"
191
+ del self.cache[node.id]
192
+ self._remove_node(node)
193
+
194
+ def get_lru_no_lock(self) -> Optional[TreeNode]:
195
+ """
196
+ Get the least recently used node that is not locked
197
+ """
198
+ return self.get_prev_no_lock(self.tail, check_id=False)
199
+
200
+ def get_leaf_lru_no_lock(self) -> Optional[TreeNode]:
201
+ """
202
+ Get the least recently used leaf node that is not locked
203
+ """
204
+ return self.get_prev_leaf_no_lock(self.tail, check_id=False)
205
+
206
+ def get_prev_no_lock(
207
+ self, node: TreeNode, check_id: bool = True
208
+ ) -> Optional[TreeNode]:
209
+ """
210
+ Get the previous (i.e. more recently used) node that is not locked
211
+ """
212
+ if check_id:
213
+ assert (
214
+ node.id in self.cache
215
+ ), f"Getting prev of node {node.id=} not in lru list"
216
+ x = getattr(node, self.prv) # x = node.prev
217
+ while getattr(x, self.lock_ref) > 0:
218
+ x = getattr(x, self.prv) # x = x.prev
219
+ # if x is the head, it means there is no node in the lru list without lock
220
+ if x == self.head:
221
+ return None
222
+ return x
223
+
224
+ def get_prev_leaf_no_lock(self, node: TreeNode, check_id: bool = True):
225
+ """
226
+ Get the previous (i.e. more recently used) leaf node that is not locked
227
+ """
228
+ if check_id:
229
+ assert (
230
+ node.id in self.cache
231
+ ), f"Getting prev of node {node.id=} not in lru list"
232
+ x = getattr(node, self.prv) # x = node.prev
233
+ while getattr(x, self.lock_ref) > 0 or len(x.children) > 0:
234
+ x = getattr(x, self.prv) # x = x.prev
235
+ # if x is the head, it means there is no leaf node in the lru list without lock
236
+ if x == self.head:
237
+ return None
238
+ return x
239
+
240
+ def in_list(self, node: Optional[TreeNode]):
241
+ """
242
+ Check if the node is in the lru list
243
+ """
244
+ if not node:
245
+ return False
246
+ return node.id in self.cache
247
+
248
+ # Note: this is expensive, only use for debug
249
+ def sanity_check_evictable_size(self):
250
+ """
251
+ Check the evictable size (i.e. the size of the nodes that are not locked)
252
+ """
253
+ node = self.get_lru_no_lock()
254
+ evictable_size = 0
255
+ while self.in_list(node):
256
+ evictable_size += (
257
+ len(node.value) if not self.mamba else len(node.mamba_value)
258
+ )
259
+ node = self.get_prev_no_lock(node)
260
+ return evictable_size
261
+
262
+ # Note: this is expensive, only use for debug or idle check
263
+ def sanity_check(self, tree_cache: "MambaRadixCache"):
264
+ """
265
+ Check if the lru list is valid by rebuilding the lru list from the tree, heapifying it, and
266
+ checking if the lru list is valid.
267
+ """
268
+ try:
269
+ if self.mamba:
270
+ nodes = tree_cache._collect_nontombstone_nodes()
271
+ else:
272
+ nodes = tree_cache._collect_all_nodes()
273
+ total_nodes = len(nodes)
274
+ total_lru = len(self.cache)
275
+ # heapify based on last_access_time
276
+ heapq.heapify(nodes)
277
+ # the root node is not in the lru list
278
+ assert len(nodes) == (
279
+ total_lru + (0 if self.mamba else 1)
280
+ ), f"len(nodes): {len(nodes)}, total_lru: {total_lru}"
281
+
282
+ x_lru = self._get_lru()
283
+ while len(nodes):
284
+ x = heapq.heappop(nodes)
285
+ if x == tree_cache.root_node:
286
+ # root node is not in the lru list
287
+ continue
288
+ assert (
289
+ x == x_lru
290
+ ), f"Incorrect LRU list, {self.mamba=}, x: {x.id=} != x_lru: {x_lru.id=}"
291
+ assert (
292
+ x_lru.full_lock_ref == 0
293
+ ), f"x_lru should not be locked when idle, {x_lru.full_lock_ref=}, {x_lru.id=}"
294
+ assert (
295
+ x_lru.mamba_lock_ref == 0
296
+ ), f"x_lru should not be locked when idle, {x_lru.mamba_lock_ref=}, {x_lru.id=}"
297
+ x_lru = getattr(x, self.prv)
298
+
299
+ if self.mamba:
300
+ evictable_size = tree_cache.mamba_evictable_size()
301
+ lru_list_evictable_size = tree_cache.mamba_lru_list_evictable_size()
302
+ else:
303
+ evictable_size = tree_cache.full_evictable_size()
304
+ lru_list_evictable_size = tree_cache.full_lru_list_evictable_size()
305
+
306
+ assert (
307
+ evictable_size == lru_list_evictable_size
308
+ ), f"{self.mamba=}, total nodes: {total_nodes}, total lru: {total_lru}, evictable size: {evictable_size} != lru list evictable size: {lru_list_evictable_size}"
309
+ except Exception as e:
310
+ msg = f"Mamba Radix tree sanity check failed, ping @yizhang2077: {e}"
311
+ logger.error(msg)
312
+ raise Exception(msg)
313
+
314
+
315
+ class MambaRadixCache(BasePrefixCache):
316
+ def __init__(
317
+ self,
318
+ req_to_token_pool: HybridReqToTokenPool,
319
+ token_to_kv_pool_allocator: TokenToKVPoolAllocator,
320
+ page_size: int,
321
+ disable: bool = False,
322
+ ):
323
+ assert isinstance(token_to_kv_pool_allocator, TokenToKVPoolAllocator)
324
+ self.req_to_token_pool = req_to_token_pool
325
+ self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
326
+
327
+ assert page_size == 1, "Only support page_size=1 in mamba radix cache now."
328
+ self.page_size = page_size
329
+ self.disable = disable
330
+
331
+ if self.token_to_kv_pool_allocator:
332
+ self.device = self.token_to_kv_pool_allocator.device
333
+ else:
334
+ self.device = torch.device("cpu")
335
+
336
+ self.key_match_fn = _key_match_page_size1
337
+ self.get_child_key_fn = get_child_key
338
+ self.reset()
339
+
340
+ ##### Public API #####
341
+
342
+ def reset(self) -> None:
343
+ self.root_node = TreeNode()
344
+ self.root_node.key = []
345
+ self.root_node.value = []
346
+ self.root_node.full_lock_ref = 1
347
+ self.root_node.mamba_lock_ref = 1
348
+ self.full_evictable_size_ = 0
349
+ self.mamba_evictable_size_ = 0
350
+ self.full_protected_size_ = 0
351
+ self.mamba_protected_size_ = 0
352
+ # LRU lists are used to maintain the order of eviction of the nodes in the tree
353
+ self.full_lru_list = LRUList(mamba=False)
354
+ self.mamba_lru_list = LRUList(mamba=True)
355
+
356
+ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
357
+ """Find the matching prefix from the radix tree.
358
+ Args:
359
+ key: A RadixKey contains token IDs to find a matching prefix.
360
+ Returns:
361
+ A tuple of a tensor of matching prefix token IDs and
362
+ the last node that contains the prefix values. Note that
363
+ this API can modify the internal state of the Radix tree.
364
+ The last node create a new child if the prefix is shorter
365
+ than the last node's value.
366
+ """
367
+ cow_mamba: bool = kwargs.get("cow_mamba", False)
368
+ req: Req = kwargs.get("req", None)
369
+
370
+ if self.disable or len(key) == 0:
371
+ return MatchResult(
372
+ device_indices=torch.empty(
373
+ (0,),
374
+ dtype=torch.int64,
375
+ device=self.device,
376
+ ),
377
+ last_device_node=self.root_node,
378
+ last_host_node=self.root_node,
379
+ )
380
+
381
+ value, last_node = self._match_prefix_helper(key)
382
+
383
+ # copy mamba state to req local space if cow is true
384
+ if cow_mamba and last_node.mamba_value is not None:
385
+ assert req.req_pool_idx is None # req_pool_idx is uninitialed
386
+
387
+ # for reqs without mamba cache
388
+ if req.mamba_pool_idx is None:
389
+ dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
390
+ # try to alloc again, protect last_node from eviction
391
+ if dst_index is None:
392
+ self.inc_lock_ref(last_node)
393
+ self.evict_mamba(1)
394
+ dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
395
+ self.dec_lock_ref(last_node)
396
+ assert dst_index is not None, "Can not alloc mamba cache"
397
+ src_index = last_node.mamba_value
398
+ self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)
399
+ req.mamba_pool_idx = dst_index[0]
400
+ else:
401
+ src_index = last_node.mamba_value
402
+ dst_index = req.mamba_pool_idx.unsqueeze(0)
403
+ self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)
404
+
405
+ if value:
406
+ value = torch.cat(value)
407
+ else:
408
+ value = torch.empty((0,), dtype=torch.int64, device=self.device)
409
+
410
+ return MatchResult(
411
+ device_indices=value,
412
+ last_device_node=last_node,
413
+ last_host_node=last_node,
414
+ )
415
+
416
+ def insert(self, key: RadixKey, value=None, mamba_value=None) -> Tuple[int, bool]:
417
+ if self.disable:
418
+ return 0
419
+
420
+ if value is None:
421
+ value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
422
+ return self._insert_helper(self.root_node, key, value, mamba_value)
423
+
424
+ def cache_finished_req(self, req: Req) -> None:
425
+ """Cache request when it finishes."""
426
+ if self.disable:
427
+ kv_indices = self.req_to_token_pool.req_to_token[
428
+ req.req_pool_idx,
429
+ : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
430
+ ]
431
+ self.token_to_kv_pool_allocator.free(kv_indices)
432
+ self.req_to_token_pool.free(req.req_pool_idx)
433
+ return
434
+
435
+ token_ids = (req.origin_input_ids + req.output_ids)[:-1]
436
+ kv_indices = self.req_to_token_pool.req_to_token[
437
+ req.req_pool_idx, : len(token_ids)
438
+ ]
439
+
440
+ page_aligned_len = len(kv_indices)
441
+ page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
442
+
443
+ # Radix Cache takes one ref in memory pool
444
+ # insert the token_ids and kv_indices into the radix tree
445
+ # Note: the insert function already frees the overlapped kv_indices
446
+ mamba_value = (
447
+ self.req_to_token_pool.get_mamba_indices(req.req_pool_idx)
448
+ .unsqueeze(-1)
449
+ .clone()
450
+ )
451
+
452
+ new_prefix_len, mamba_exist = self.insert(
453
+ RadixKey(token_ids[:page_aligned_len], req.extra_key),
454
+ page_aligned_kv_indices,
455
+ mamba_value,
456
+ )
457
+
458
+ self.token_to_kv_pool_allocator.free(
459
+ kv_indices[len(req.prefix_indices) : new_prefix_len]
460
+ )
461
+
462
+ self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=mamba_exist)
463
+ self.dec_lock_ref(req.last_node)
464
+
465
+ def cache_unfinished_req(self, req: Req, chunked=False) -> None:
466
+ """Cache request when it is unfinished."""
467
+ if self.disable:
468
+ kv_indices = self.req_to_token_pool.req_to_token[
469
+ req.req_pool_idx, : len(req.fill_ids)
470
+ ]
471
+ # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
472
+ req.prefix_indices = kv_indices
473
+ return
474
+
475
+ token_ids = req.fill_ids
476
+ kv_indices = self.req_to_token_pool.req_to_token[
477
+ req.req_pool_idx, : len(token_ids)
478
+ ]
479
+ page_aligned_len = len(kv_indices)
480
+ page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
481
+ page_aligned_token_ids = token_ids[:page_aligned_len]
482
+
483
+ mamba_value = self.req_to_token_pool.get_mamba_indices(
484
+ req.req_pool_idx
485
+ ).unsqueeze(-1)
486
+ # radix tree mamba value is forked from req space
487
+ mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value)
488
+
489
+ # if alloc mamba cache failed, do evict and alloc again
490
+ if mamba_value_forked is None:
491
+ self.evict_mamba(1)
492
+ mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(
493
+ mamba_value
494
+ )
495
+ assert mamba_value_forked is not None, "Can not alloc mamba cache"
496
+ new_prefix_len, mamba_exist = self.insert(
497
+ RadixKey(page_aligned_token_ids, req.extra_key),
498
+ page_aligned_kv_indices,
499
+ mamba_value_forked,
500
+ )
501
+ self.token_to_kv_pool_allocator.free(
502
+ kv_indices[len(req.prefix_indices) : new_prefix_len]
503
+ )
504
+ # there is a mamba cache in radix cache, release it
505
+ if mamba_exist:
506
+ self.req_to_token_pool.mamba_pool.free(mamba_value_forked)
507
+
508
+ # The prefix indices could be updated, reuse it
509
+ new_indices, new_last_node, _, _ = self.match_prefix(
510
+ RadixKey(page_aligned_token_ids, req.extra_key)
511
+ )
512
+
513
+ if not mamba_exist:
514
+ assert torch.equal(new_last_node.mamba_value, mamba_value_forked)
515
+
516
+ assert len(req.prefix_indices) <= len(
517
+ new_indices
518
+ ), f"{req.prefix_indices=}, {new_indices=}"
519
+ assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
520
+
521
+ self.req_to_token_pool.write(
522
+ (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
523
+ new_indices[len(req.prefix_indices) :],
524
+ )
525
+
526
+ self.dec_lock_ref(req.last_node)
527
+ self.inc_lock_ref(new_last_node)
528
+
529
+ # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
530
+ req.prefix_indices = new_indices
531
+ req.last_node = new_last_node
532
+
533
+ def pretty_print(self) -> None:
534
+ self._print_helper(self.root_node, 0)
535
+ total_size, total_mamba_size = self._total_size_helper()
536
+ print(f"#full_tokens: {total_size}, #mamba_num: {total_mamba_size}")
537
+
538
+ def total_size(self) -> Tuple[int, int]:
539
+ return self._total_size_helper()
540
+
541
+ def _evict_leaf_node(
542
+ self, x: TreeNode, is_evict_mamba: bool
543
+ ) -> Tuple[int, int, TreeNode, TreeNode]:
544
+ assert (
545
+ x.full_lock_ref == 0 and x.mamba_lock_ref == 0
546
+ ), f"evict leaf node invalid with {x.id=} {x.full_lock_ref=} {x.mamba_lock_ref=}"
547
+
548
+ assert x.mamba_value is not None, f"leaf node mamba value is not None, {x.id=}"
549
+ # 1. a leaf node, free full tokens and mamba
550
+ self.token_to_kv_pool_allocator.free(x.value)
551
+ full_num_evicted = len(x.value)
552
+ self.req_to_token_pool.mamba_pool.free(x.mamba_value)
553
+ mamba_num_evicted = len(x.mamba_value)
554
+
555
+ # 2. get the next node, update the lru lists
556
+ if is_evict_mamba:
557
+ x_next = self.mamba_lru_list.get_prev_no_lock(x)
558
+ else:
559
+ x_next = self.full_lru_list.get_prev_leaf_no_lock(x)
560
+ self.full_lru_list.remove_node(x)
561
+ self.mamba_lru_list.remove_node(x)
562
+
563
+ # 3. delete the leaf node
564
+ self._delete_leaf(x)
565
+
566
+ # 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone
567
+ x, leaf_full_num_evicted = self._iteratively_delete_tombstone_leaf(x)
568
+ full_num_evicted += leaf_full_num_evicted
569
+ return full_num_evicted, mamba_num_evicted, x, x_next
570
+
571
+ def evict_mamba(self, mamba_num: int) -> None:
572
+ if self.disable or mamba_num <= 0:
573
+ return
574
+ # get the least recently used node that is not locked, doesn't have to be a leaf
575
+ x = self.mamba_lru_list.get_lru_no_lock()
576
+ mamba_num_evicted = 0
577
+ # evict lru leaf nodes until mamba_num_tokens is reached
578
+ while mamba_num_evicted < mamba_num and (self.mamba_lru_list.in_list(x)):
579
+ assert x.mamba_value is not None, f"node has no mamba value, {x.id=}"
580
+ assert (
581
+ len(x.mamba_value) == 1
582
+ ), f"node has abnormal mamba length, {x.id=}, {len(x.mamba_value)=}"
583
+ assert x != self.root_node, f"root node is not evictable, {x.id=}"
584
+ assert x.mamba_lock_ref == 0, f"node is in use by mamba kv indices, {x.id=}"
585
+
586
+ if len(x.children) > 0:
587
+ # 1. an internal node, free mamba tokens.
588
+ self.req_to_token_pool.mamba_pool.free(x.mamba_value)
589
+ mamba_num_evicted += len(x.mamba_value)
590
+
591
+ # 2. get the next node, update the lru lists
592
+ x_next = self.mamba_lru_list.get_prev_no_lock(x)
593
+ self.mamba_lru_list.remove_node(x)
594
+
595
+ # 3. tombstone the node
596
+ self._tombstone_internal_node(x)
597
+ else:
598
+ _, mamba_evicted_delta, _, x_next = self._evict_leaf_node(x, True)
599
+ mamba_num_evicted += mamba_evicted_delta
600
+
601
+ x = x_next
602
+
603
+ def evict(self, full_num_tokens: int) -> None:
604
+ if self.disable or full_num_tokens <= 0:
605
+ return
606
+
607
+ full_num_evicted = 0
608
+ # get the least recently used leaf node that is not locked
609
+ x = self.full_lru_list.get_leaf_lru_no_lock()
610
+
611
+ while full_num_evicted < full_num_tokens and self.full_lru_list.in_list(x):
612
+ assert (
613
+ x != self.root_node
614
+ ), f"root node should not exist in full lru list, {x.id=}"
615
+ full_num_evicted_delta, _, x, x_next = self._evict_leaf_node(x, False)
616
+ full_num_evicted += full_num_evicted_delta
617
+
618
+ # if parent has no more children, it is a leaf. It is possible that this node is lru, so
619
+ # we need to get the first leaf node in the lru list
620
+ if len(x.parent.children) == 0:
621
+ x_next = self.full_lru_list.get_leaf_lru_no_lock()
622
+
623
+ x = x_next
624
+
625
+ def inc_lock_ref(self, node: TreeNode) -> Optional[int]:
626
+ """
627
+ Increment the lock reference count for the node.
628
+ It locks the full_lock_ref for nodes between the [last node, root), exclusive.
629
+ It locks the mamba_lock_ref for current node if its mamba_value exists.
630
+ """
631
+ if self.disable:
632
+ return None
633
+
634
+ # protect mamba value in current node if it exists
635
+ if node.mamba_value is not None:
636
+ if node.mamba_lock_ref == 0:
637
+ self.mamba_evictable_size_ -= len(node.mamba_value)
638
+ self.mamba_protected_size_ += len(node.mamba_value)
639
+ node.mamba_lock_ref += 1
640
+
641
+ while node != self.root_node:
642
+ # lock full from node to root
643
+ assert (
644
+ node.full_lock_ref >= 0
645
+ ), f"inc_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
646
+ if node.full_lock_ref == 0:
647
+ self.full_evictable_size_ -= len(node.value)
648
+ self.full_protected_size_ += len(node.value)
649
+ node.full_lock_ref += 1
650
+ node = node.parent
651
+ return None
652
+
653
+ def dec_lock_ref(self, node: TreeNode):
654
+ """
655
+ Decrement the lock reference count for the node.
656
+ It unlocks the full_lock_ref for nodes between the [last node, root), exclusive.
657
+ It unlocks the mamba_lock_ref for current node if its mamba_value exists.
658
+ """
659
+ if self.disable:
660
+ return
661
+
662
+ if node.mamba_value is not None:
663
+ assert (
664
+ node.mamba_lock_ref > 0
665
+ ), f"dec_lock_ref on node with {node.mamba_lock_ref=}, {node.id=}"
666
+ if node.mamba_lock_ref == 1:
667
+ self.mamba_evictable_size_ += len(node.mamba_value)
668
+ self.mamba_protected_size_ -= len(node.mamba_value)
669
+ node.mamba_lock_ref -= 1
670
+
671
+ while node != self.root_node:
672
+ assert (
673
+ node.full_lock_ref > 0
674
+ ), f"dec_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
675
+ if node.full_lock_ref == 1:
676
+ self.full_evictable_size_ += len(node.value)
677
+ self.full_protected_size_ -= len(node.value)
678
+ node.full_lock_ref -= 1
679
+ node = node.parent
680
+
681
+ def sanity_check(self):
682
+ self.full_lru_list.sanity_check(self)
683
+ self.mamba_lru_list.sanity_check(self)
684
+
685
+ def evictable_size(self) -> Tuple[int, int]:
686
+ # Note: use full_evictable_size() and mamba_evictable_size() instead.
687
+ raise NotImplementedError
688
+
689
+ def full_evictable_size(self) -> int:
690
+ return self.full_evictable_size_
691
+
692
+ def mamba_evictable_size(self) -> int:
693
+ return self.mamba_evictable_size_
694
+
695
+ # Note: this is expensive, only use for debug
696
+ def full_lru_list_evictable_size(self) -> int:
697
+ return self.full_lru_list.sanity_check_evictable_size()
698
+
699
+ # Note: this is expensive, only use for debug
700
+ def mamba_lru_list_evictable_size(self) -> int:
701
+ return self.mamba_lru_list.sanity_check_evictable_size()
702
+
703
+ def protected_size(self) -> Tuple[int, int]:
704
+ # Note: use full_protected_size() and mamba_protected_size() instead.
705
+ raise NotImplementedError
706
+
707
+ def full_protected_size(self) -> int:
708
+ # protected size refers to the size of the full cache that is locked
709
+ return self.full_protected_size_
710
+
711
+ def mamba_protected_size(self) -> int:
712
+ # protected size refers to the size of the mamba cache that is locked
713
+ return self.mamba_protected_size_
714
+
715
+ def all_values_flatten(self) -> torch.Tensor:
716
+ values = []
717
+
718
+ def _dfs_helper(node: TreeNode):
719
+ for _, child in node.children.items():
720
+ values.append(child.value)
721
+ _dfs_helper(child)
722
+
723
+ _dfs_helper(self.root_node)
724
+ return torch.cat(values)
725
+
726
+ ##### Internal Helper Functions #####
727
+
728
+ def _match_prefix_helper(
729
+ self, key: RadixKey
730
+ ) -> Tuple[List[torch.Tensor], TreeNode]:
731
+ """
732
+ Mamba prefix matching helper. It factors in the sliding window size such that
733
+ the matched node is guaranteed to either 1. connected to root without mamba tombstone,
734
+ or 2. the number of matching tokens from the matched node to the last mamba tombstone
735
+ node is greater than or equal to the sliding window size.
736
+ """
737
+ node = self.root_node
738
+ child_key = self.get_child_key_fn(key)
739
+
740
+ value = []
741
+ best_value_len = 0
742
+ best_last_node = node
743
+ while len(key) > 0 and child_key in node.children.keys():
744
+ child = node.children[child_key]
745
+ # update best_value_len and best_last_node if needed
746
+ if node.mamba_value is not None:
747
+ best_value_len = len(value)
748
+ best_last_node = node
749
+
750
+ prefix_len = self.key_match_fn(child.key, key)
751
+ if prefix_len < len(child.key):
752
+ new_node = self._split_node(child.key, child, prefix_len)
753
+ value.append(new_node.value)
754
+ node = new_node
755
+ break
756
+ else:
757
+ value.append(child.value)
758
+ node = child
759
+ key = key[prefix_len:]
760
+
761
+ if len(key):
762
+ child_key = self.get_child_key_fn(key)
763
+ # handle best_value_len and best_last_node, for the case that last node is fully matched
764
+ if node.mamba_value is not None:
765
+ best_value_len = len(value)
766
+ best_last_node = node
767
+
768
+ # update time for matched nodes, and make nodes closer to root to be least recently used
769
+ # this allows mamba to evict nodes closer to root first
770
+ self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
771
+ self.mamba_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
772
+
773
+ # This last_access_time is for sanity check, can be deleted after validation in production
774
+ cur_time = time.monotonic()
775
+ while node:
776
+ node.last_access_time = cur_time
777
+ cur_time -= 0.0001
778
+ node = node.parent
779
+
780
+ return value[:best_value_len], best_last_node
781
+
+    def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode:
+        # new_node -> child
+        new_node = TreeNode()
+        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
+        new_node.parent = child.parent
+        new_node.mamba_value = None  # mamba cache cannot be split
+        new_node.full_lock_ref = child.full_lock_ref
+        new_node.mamba_lock_ref = 0
+        new_node.key = child.key[:split_len]
+        new_node.value = child.value[:split_len]
+
+        # child time should be later than parent's time for mamba tombstone
+        child.last_access_time = time.monotonic()
+
+        self.full_lru_list.remove_node(child)
+        if child.mamba_value is not None:
+            self.mamba_lru_list.remove_node(child)
+        child.parent = new_node
+        child.key = child.key[split_len:]
+        child.value = child.value[split_len:]
+        new_node.parent.children[self.get_child_key_fn(key)] = new_node
+
+        # insert the new node and child into the lru lists, insert
+        # parent first so that parent is after child in the lru list
+        self.full_lru_list.insert_mru(new_node)
+        self.full_lru_list.insert_mru(child)
+        if child.mamba_value is not None:
+            self.mamba_lru_list.insert_mru(child)
+        return new_node
+
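A minimal sketch of the split itself with plain Python lists (hypothetical structures, not the method above): the first split_len tokens move into a new parent node, the remainder stays in the child, and the new parent starts as a tombstone because a recorded mamba state is only valid at the original node's full length.

# Toy sketch of splitting a radix node at `split_len` (hypothetical ToyNode).
class ToyNode:
    def __init__(self, key, value, mamba_state=None):
        self.key = key                # token ids owned by this node
        self.value = value            # matching KV-cache slot indices
        self.mamba_state = mamba_state
        self.children = {}

def split(child, split_len):
    parent = ToyNode(child.key[:split_len], child.value[:split_len])
    parent.mamba_state = None                      # a mamba state cannot be split
    parent.children[child.key[split_len]] = child  # re-attach the shortened child
    child.key = child.key[split_len:]
    child.value = child.value[split_len:]
    return parent

node = ToyNode([7, 8, 9, 10], [100, 101, 102, 103], mamba_state="s")
top = split(node, 2)
print(top.key, node.key)   # [7, 8] [9, 10]; only `node` keeps its mamba state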
+    def _insert_helper(
+        self,
+        node: TreeNode,
+        key: RadixKey,
+        value,
+        mamba_value,
+    ) -> Tuple[int, bool]:
+        # Update the last access time from root to leaf, so that
+        # mamba will tombstone the node closer to root first
+        assert mamba_value is not None, "Mamba value should not be None here."
+        node.last_access_time = time.monotonic()
+        if node != self.root_node:
+            self.full_lru_list.reset_node_mru(node)
+            if node.mamba_value is not None:
+                self.mamba_lru_list.reset_node_mru(node)
+        if len(key) == 0:
+            return 0, True
+
+        child_key = self.get_child_key_fn(key)
+
+        total_prefix_length = 0
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.monotonic()
+            self.full_lru_list.reset_node_mru(node)
+            if node.mamba_value is not None:
+                self.mamba_lru_list.reset_node_mru(node)
+            prefix_len = self.key_match_fn(node.key, key)
+            total_prefix_length += prefix_len
+            key = key[prefix_len:]
+            value = value[prefix_len:]
+
+            if prefix_len < len(node.key):
+                new_node = self._split_node(node.key, node, prefix_len)
+                node = new_node
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)
+
+        mamba_value_exist = False
+        if len(key):
+            new_node = TreeNode()
+            new_node.parent = node
+            new_node.key = key
+            new_node.value = value
+            new_node.mamba_value = mamba_value
+            self.full_lru_list.insert_mru(new_node)
+            self.full_evictable_size_ += len(value)
+            self.mamba_evictable_size_ += len(mamba_value)
+            self.mamba_lru_list.insert_mru(new_node)
+            node.children[child_key] = new_node
+        elif node.mamba_value is None:  # add for mamba tombstone
+            node.mamba_value = mamba_value
+            self.mamba_evictable_size_ += len(mamba_value)
+            self.mamba_lru_list.insert_mru(node)
+        else:
+            mamba_value_exist = True
+            self.mamba_lru_list.reset_node_mru(node)
+
+        return total_prefix_length, mamba_value_exist
+
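A condensed sketch of the three outcomes the insert can reach once the walk stops, using a toy dict node (hypothetical): leftover key tokens create a new leaf, a fully matched tombstone gets its mamba state filled back in, and a fully matched node that already has a state is only marked most recently used.

# Toy sketch of the three terminal cases of an insert (hypothetical structures).
def finish_insert(node, leftover_key, leftover_value, mamba_value, children):
    """Return a label describing which branch a real insert would take."""
    if leftover_key:                      # 1) unmatched suffix -> new leaf node
        children[leftover_key[0]] = {"key": leftover_key, "value": leftover_value,
                                     "mamba": mamba_value}
        return "new_leaf"
    if node.get("mamba") is None:         # 2) fully matched tombstone -> revive state
        node["mamba"] = mamba_value
        return "tombstone_revived"
    return "already_cached"               # 3) state already present -> only bump MRU

node = {"key": [1, 2], "value": [10, 11], "mamba": None}
print(finish_insert(node, [], [], "state", {}))   # tombstone_revived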
+    def _iteratively_delete_tombstone_leaf(
+        self, node: TreeNode
+    ) -> Tuple[TreeNode, int]:
+        full_num_evicted = 0
+        while node.parent.mamba_value is None and len(node.parent.children) == 0:
+            # root node is not evictable
+            if node.parent == self.root_node:
+                break
+            # if locked, the node is in use, skip
+            if node.parent.full_lock_ref > 0:
+                break
+            assert (
+                node.parent.mamba_lock_ref == 0
+            ), f"tombstone mamba_lock_ref should always be 0, {node.parent.full_lock_ref=}, {node.parent.mamba_lock_ref=}, {node.parent.id=}"
+            # deleting a tombstone node evicts its full tokens
+            self.token_to_kv_pool_allocator.free(node.parent.value)
+            full_num_evicted += len(node.parent.value)
+            self.full_lru_list.remove_node(node.parent)
+            self._delete_tombstone_leaf(node.parent)
+            node = node.parent
+
+        return node, full_num_evicted
+
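A minimal sketch of this cascading cleanup with toy fields (hypothetical, not the classes in this diff): after a leaf is removed, childless, unlocked tombstone ancestors are freed one by one, walking toward the root.

# Toy sketch: free childless, unlocked tombstone ancestors after deleting a leaf.
class ToyNode:
    def __init__(self, value, parent=None, mamba_state=None, lock_ref=0):
        self.value = value                # KV-cache slots held by this node
        self.parent = parent
        self.mamba_state = mamba_state    # None -> tombstone
        self.lock_ref = lock_ref
        self.children = {}
        if parent is not None:
            parent.children[id(self)] = self

def cleanup(node, root):
    freed = 0
    while (node.parent is not root
           and node.parent.mamba_state is None
           and not node.parent.children
           and node.parent.lock_ref == 0):
        freed += len(node.parent.value)
        node = node.parent
        node.parent.children.pop(id(node), None)   # detach the freed node
    return freed

root = ToyNode([], mamba_state="root")
a = ToyNode([0, 1], parent=root)                   # tombstone ancestor
leaf = ToyNode([2, 3], parent=a, mamba_state="s")
a.children.clear()                                 # pretend the leaf was just deleted
print(cleanup(leaf, root))                         # 2: node `a` is freed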
+    def _delete_leaf(self, node: TreeNode) -> None:
+        assert (
+            node.mamba_value is not None
+        ), f"Invariant violated: leaf node is a tombstone, {node.id=}"
+        assert len(node.children) == 0, f"leaf node has children, {node.id=}"
+        for k, v in node.parent.children.items():
+            if v == node:
+                break
+        del node.parent.children[k]
+        self.full_evictable_size_ -= len(node.key)
+        self.mamba_evictable_size_ -= len(node.mamba_value)
+
+    def _tombstone_internal_node(self, node: TreeNode) -> None:
+        assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}"
+        self.mamba_evictable_size_ -= len(node.mamba_value)
+        node.mamba_value = None
+
+    def _delete_tombstone_leaf(self, node: TreeNode) -> None:
+        assert (
+            node.mamba_value is None
+        ), f"Deleting an unexpected non-tombstone leaf node, {node.id=}"
+        assert len(node.children) == 0, f"leaf node has children, {node.id=}"
+        for k, v in node.parent.children.items():
+            if v == node:
+                break
+        del node.parent.children[k]
+        self.full_evictable_size_ -= len(node.key)
+
+    def _collect_leaves(self) -> List[TreeNode]:
+        ret_list = []
+        stack = [self.root_node]
+
+        while stack:
+            cur_node = stack.pop()
+            if len(cur_node.children) == 0:
+                ret_list.append(cur_node)
+            else:
+                stack.extend(cur_node.children.values())
+
+        return ret_list
+
+    def _collect_nontombstone_nodes(self) -> List[TreeNode]:
+        ret_list = []
+        stack = [self.root_node]
+
+        while stack:
+            cur_node = stack.pop()
+            if cur_node.mamba_value is not None:
+                ret_list.append(cur_node)
+            stack.extend(cur_node.children.values())
+
+        return ret_list
+
+    def _collect_all_nodes(self) -> List[TreeNode]:
+        ret_list = []
+        stack = [self.root_node]
+        while stack:
+            cur_node = stack.pop()
+            ret_list.append(cur_node)
+            stack.extend(cur_node.children.values())
+        return ret_list
+
+    def _print_helper(self, node: TreeNode, indent: int) -> None:
+        """Prints the radix tree in a human-readable format."""
+        stack = [(node, indent)]
+        while stack:
+            current_node, current_indent = stack.pop()
+            print(
+                " " * current_indent,
+                f"[{current_node.id}]",
+                len(current_node.key),
+                f"fr={current_node.full_lock_ref}",
+                f"mr={current_node.mamba_lock_ref}",
+                f"fll={self.full_lru_list.in_list(current_node)}",
+                f"mll={self.mamba_lru_list.in_list(current_node)}",
+                f"mv={current_node.mamba_value}",
+            )
+            for key, child in current_node.children.items():
+                stack.append((child, current_indent + 2))
+
+                assert key == self.get_child_key_fn(
+                    child.key
+                ), f"{key=}, {self.get_child_key_fn(child.key)=}"
+
+    def _total_size_helper(self) -> Tuple[int, int]:
+        total_size = 0
+        total_mamba_size = 0
+        stack = [self.root_node]
+        while stack:
+            current_node = stack.pop()
+            total_size += len(current_node.value)
+            if current_node.mamba_value is not None:
+                total_mamba_size += len(current_node.mamba_value)
+            for child in current_node.children.values():
+                if child.evicted:
+                    continue
+                stack.append(child)
+        return total_size, total_mamba_size