sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0

sglang/srt/managers/tp_worker.py

@@ -22,13 +22,9 @@ import torch
 
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_pp_group, get_world_group
-from sglang.srt.hf_transformers_utils import (
-    get_processor,
-    get_tokenizer,
-    get_tokenizer_from_processor,
-)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsUpdateGroupReqInput,
@@ -42,11 +38,20 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardBatchOutput,
+    PPProxyTensors,
+)
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
+from sglang.srt.utils.hf_transformers_utils import (
+    get_processor,
+    get_tokenizer,
+    get_tokenizer_from_processor,
+)
+from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
 
 if TYPE_CHECKING:
     from sglang.srt.managers.cache_controller import LayerDoneCounter
@@ -90,7 +95,6 @@ class TpModelWorker:
                 else server_args.speculative_draft_model_revision
             ),
             is_draft_model=is_draft_worker,
-            tp_rank=tp_rank,
         )
 
         self.model_runner = ModelRunner(
@@ -149,8 +153,8 @@
         assert self.max_running_requests > 0, "max_running_request is zero"
         self.max_queued_requests = server_args.max_queued_requests
         assert (
-            self.max_queued_requests > 0
-        ), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
+            self.max_queued_requests is None or self.max_queued_requests >= 1
+        ), "If configured, max_queued_requests must be at least 1 for any work to be scheduled."
         self.max_req_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
@@ -233,10 +237,8 @@
         self,
        model_worker_batch: ModelWorkerBatch,
        launch_done: Optional[threading.Event] = None,
-        skip_sample: bool = False,
-    ) -> Tuple[
-        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
-    ]:
+        is_verify: bool = False,
+    ) -> ForwardBatchOutput:
         # update the consumer index of hicache to the running batch
         self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
 
@@ -257,29 +259,31 @@
             if launch_done is not None:
                 launch_done.set()
 
-            if skip_sample:
-                next_token_ids = None
-                # For prefill-only requests, we still need to compute logprobs even when sampling is skipped
-                if (
-                    model_worker_batch.is_prefill_only
-                    and model_worker_batch.return_logprob
-                ):
-                    # Compute logprobs without full sampling
-                    self.model_runner.compute_logprobs_only(
-                        logits_output, model_worker_batch
-                    )
-            else:
-                next_token_ids = self.model_runner.sample(
+            skip_sample = is_verify or model_worker_batch.is_prefill_only
+            next_token_ids = None
+
+            if not skip_sample:
+                next_token_ids = self.model_runner.sample(logits_output, forward_batch)
+            elif model_worker_batch.return_logprob and not is_verify:
+                # NOTE: Compute logprobs without full sampling
+                self.model_runner.compute_logprobs_only(
                     logits_output, model_worker_batch
                 )
 
-            return logits_output, next_token_ids, can_run_cuda_graph
+            return ForwardBatchOutput(
+                logits_output=logits_output,
+                next_token_ids=next_token_ids,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )
         else:
             pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return pp_proxy_tensors.tensors, None, can_run_cuda_graph
+            return ForwardBatchOutput(
+                pp_proxy_tensors=pp_proxy_tensors,
+                can_run_cuda_graph=can_run_cuda_graph,
+            )
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
@@ -304,6 +308,12 @@
         )
         return success, message
 
+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.destroy_weights_update_group(
+            recv_req.group_name,
+        )
+        return success, message
+
     def init_weights_send_group_for_remote_instance(
         self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
     ):
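
forward_batch_generation now returns a ForwardBatchOutput instead of a (logits_output, next_token_ids, can_run_cuda_graph) tuple, and the pipeline-parallel branch wraps its proxy tensors in the same type. The class itself is defined in sglang/srt/model_executor/forward_batch_info.py, which is not expanded in this diff; the sketch below is only a guess at its shape, inferred from the fields used above (the defaults and loose types are assumptions, not the real definition):

# Illustrative sketch only; the real ForwardBatchOutput lives in
# sglang/srt/model_executor/forward_batch_info.py and may differ.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ForwardBatchOutput:
    # Set on the last pipeline-parallel rank after a generation forward pass.
    logits_output: Optional[object] = None  # LogitsProcessorOutput in sglang
    # Sampled token ids; None when sampling is skipped (verify / prefill-only),
    # or negative "future" placeholders when returned by the overlap worker.
    next_token_ids: Optional[torch.Tensor] = None
    # Intermediate tensors handed to the next pipeline-parallel rank.
    pp_proxy_tensors: Optional[object] = None  # PPProxyTensors in sglang
    # Whether the forward pass was served by a captured CUDA graph.
    can_run_cuda_graph: bool = False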

sglang/srt/managers/tp_worker_overlap_thread.py

@@ -25,6 +25,7 @@ import psutil
 import torch
 
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
     GetWeightsByNameReqInput,
     InitWeightsSendGroupForRemoteInstanceReqInput,
     InitWeightsUpdateGroupReqInput,
@@ -35,10 +36,12 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
+from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import DynamicGradMode, get_compiler_backend
+from sglang.srt.utils import DynamicGradMode
 from sglang.utils import get_exception_traceback
 
 if TYPE_CHECKING:
@@ -47,15 +50,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def resolve_future_token_ids(input_ids, future_token_ids_map):
-    input_ids[:] = torch.where(
-        input_ids < 0,
-        future_token_ids_map[torch.clamp(-input_ids, min=0)],
-        input_ids,
-    )
-
-
 class TpModelWorkerClient:
     """A tensor parallel model worker."""
 
@@ -78,11 +72,7 @@ class TpModelWorkerClient:
         self.gpu_id = gpu_id
 
         # Init future mappings
-        self.future_token_ids_ct = 0
-        self.future_token_ids_limit = self.max_running_requests * 3
-        self.future_token_ids_map = torch.empty(
-            (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
-        )
+        self.future_map = FutureMap(self.max_running_requests, self.device)
 
         # Launch threads
         self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
@@ -152,7 +142,7 @@
         batch_lists: List = [None] * 2
 
         while True:
-            model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
+            model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
             if not model_worker_batch:
                 break
 
@@ -168,17 +158,18 @@
             copy_done = torch.get_device_module(self.device).Event()
 
             # Resolve future tokens in the input
-            input_ids = model_worker_batch.input_ids
-            resolve_future_token_ids(input_ids, self.future_token_ids_map)
+            self.future_map.resolve_future(model_worker_batch)
 
             # Run forward
+            forward_batch_output = self.worker.forward_batch_generation(
+                model_worker_batch,
+                model_worker_batch.launch_done,
+            )
+
             logits_output, next_token_ids, can_run_cuda_graph = (
-                self.worker.forward_batch_generation(
-                    model_worker_batch,
-                    model_worker_batch.launch_done,
-                    # Skip sampling for prefill-only requests
-                    skip_sample=model_worker_batch.is_prefill_only,
-                )
+                forward_batch_output.logits_output,
+                forward_batch_output.next_token_ids,
+                forward_batch_output.can_run_cuda_graph,
             )
 
             # Update the future token ids map
@@ -186,9 +177,9 @@
             if model_worker_batch.is_prefill_only:
                 # For prefill-only requests, create dummy token IDs on CPU
                 next_token_ids = torch.zeros(bs, dtype=torch.long)
-            self.future_token_ids_map[
-                future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
-            ] = next_token_ids
+
+            # store the future indices into future map
+            self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
 
             # Copy results to the CPU
             if model_worker_batch.return_logprob:
@@ -239,7 +230,7 @@
 
     def forward_batch_generation(
         self, model_worker_batch: ModelWorkerBatch
-    ) -> Tuple[None, torch.Tensor, bool]:
+    ) -> ForwardBatchOutput:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -254,21 +245,18 @@
         sync_event.record(self.scheduler_stream)
 
         # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
-
-        # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = torch.arange(
-            -(self.future_token_ids_ct + 1),
-            -(self.future_token_ids_ct + 1 + bs),
-            -1,
-            dtype=torch.int64,
-            device=self.device,
+        cur_future_map_ct = self.future_map.update_ct(bs)
+        self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
+
+        # get this forward batch's future token ids
+        future_next_token_ids = self.future_map.update_next_future(
+            cur_future_map_ct, bs
+        )
+        return ForwardBatchOutput(
+            next_token_ids=future_next_token_ids,
+            can_run_cuda_graph=False,
         )
-        self.future_token_ids_ct = (
-            self.future_token_ids_ct + bs
-        ) % self.future_token_ids_limit
-        return None, future_next_token_ids, False
 
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
@@ -278,6 +266,10 @@
         success, message = self.worker.init_weights_update_group(recv_req)
         return success, message
 
+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        success, message = self.worker.destroy_weights_update_group(recv_req)
+        return success, message
+
     def init_weights_send_group_for_remote_instance(
         self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
     ):
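
TpModelWorkerClient now delegates all future-token bookkeeping to FutureMap from sglang/srt/managers/overlap_utils.py (a new +53-line module whose body is not shown in this diff). The sketch below only illustrates the interface exercised above (update_ct, update_next_future, store_to_map, resolve_future) by reusing the negative-placeholder scheme of the removed inline code; it is an assumption, not the actual implementation:

import torch


class FutureMap:
    """Illustrative sketch: maps negative placeholder ids to real token ids."""

    def __init__(self, max_running_requests: int, device: str):
        self.future_ct = 0
        self.future_limit = max_running_requests * 3
        # Buffer indexed by the (positive) magnitude of a placeholder id.
        self.token_ids_buf = torch.empty(
            (max_running_requests * 5,), dtype=torch.int64, device=device
        )

    def update_ct(self, bs: int) -> int:
        """Reserve `bs` slots and return the counter value before the update."""
        cur = self.future_ct
        self.future_ct = (self.future_ct + bs) % self.future_limit
        return cur

    def update_next_future(self, ct: int, bs: int) -> torch.Tensor:
        """Return placeholder ids -(ct+1) .. -(ct+bs) for this batch."""
        return torch.arange(
            -(ct + 1), -(ct + 1 + bs), -1,
            dtype=torch.int64, device=self.token_ids_buf.device,
        )

    def store_to_map(self, ct: int, bs: int, next_token_ids: torch.Tensor):
        """Fill the reserved slots once the real token ids are known."""
        self.token_ids_buf[ct + 1 : ct + bs + 1] = next_token_ids

    def resolve_future(self, model_worker_batch):
        """Replace negative placeholders in input_ids with the stored token ids."""
        input_ids = model_worker_batch.input_ids
        input_ids[:] = torch.where(
            input_ids < 0,
            self.token_ids_buf[torch.clamp(-input_ids, min=0)],
            input_ids,
        )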

sglang/srt/managers/utils.py

@@ -2,11 +2,10 @@ from __future__ import annotations
 
 import logging
 import multiprocessing as mp
-from http import HTTPStatus
 from typing import TYPE_CHECKING, Dict, List, Optional
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
 
 if TYPE_CHECKING:
@@ -97,46 +96,3 @@ def get_logprob_from_pp_outputs(
     ]
 
     return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
-
-
-class DPBalanceMeta:
-    """
-    This class will be use in scheduler and dp controller
-    """
-
-    def __init__(self, num_workers: int):
-        self.num_workers = num_workers
-        self._manager = mp.Manager()
-        self.mutex = self._manager.Lock()
-
-        init_local_tokens = [0] * self.num_workers
-        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
-
-        self.shared_state = self._manager.Namespace()
-        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
-        self.shared_state.onfly_info = self._manager.list(init_onfly_info)
-
-    def destructor(self):
-        # we must destructor this class manually
-        self._manager.shutdown()
-
-    def get_shared_onfly(self) -> List[Dict[int, int]]:
-        return [dict(d) for d in self.shared_state.onfly_info]
-
-    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
-        self.shared_state.onfly_info = data
-
-    def get_shared_local_tokens(self) -> List[int]:
-        return list(self.shared_state.local_tokens)
-
-    def set_shared_local_tokens(self, data: List[int]):
-        self.shared_state.local_tokens = data
-
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        del state["_manager"]
-        return state
-
-    def __setstate__(self, state):
-        self.__dict__.update(state)
-        self._manager = None

sglang/srt/mem_cache/allocator.py

@@ -27,7 +27,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.mem_cache.memory_pool import SWAKVPool
-from sglang.srt.utils import get_bool_env_var, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import KVCache
@@ -294,7 +294,6 @@ def alloc_extend_kernel(
     last_loc_ptr,
     free_page_ptr,
     out_indices,
-    ret_values,
     bs_upper: tl.constexpr,
     page_size: tl.constexpr,
     max_num_extend_tokens: tl.constexpr,
@@ -323,13 +322,6 @@
     sum_num_new_pages = tl.sum(num_new_pages)
    new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
 
-    # Return value
-    if pid == tl.num_programs(0) - 1:
-        merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
-            tl.int64
-        )
-        tl.store(ret_values, merged_value)
-
     # Part 1: fill the old partial page
     last_loc = tl.load(last_loc_ptr + pid)
     num_part1 = (
@@ -381,7 +373,6 @@ def alloc_decode_kernel(
     last_loc_ptr,
     free_page_ptr,
     out_indices,
-    ret_values,
     bs_upper: tl.constexpr,
     page_size: tl.constexpr,
 ):
@@ -404,10 +395,6 @@
     sum_num_new_pages = tl.sum(num_new_pages)
     new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
 
-    # Return value
-    if pid == tl.num_programs(0) - 1:
-        tl.store(ret_values, sum_num_new_pages)
-
     if num_page_start_loc_self == 0:
         last_loc = tl.load(last_loc_ptr + pid)
         tl.store(out_indices + pid, last_loc + 1)
@@ -438,7 +425,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
-        self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
         self.seen_max_num_extend_tokens_next_power_of_2 = 1
         self.clear()
 
@@ -468,7 +454,9 @@
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
+        prefix_lens_cpu: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
     ):
@@ -497,7 +485,6 @@
             last_loc,
             self.free_pages,
             out_indices,
-            self.ret_values,
             next_power_of_2(bs),
             self.page_size,
             self.seen_max_num_extend_tokens_next_power_of_2,
@@ -506,8 +493,11 @@
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)
 
-        merged_value = self.ret_values.item()
-        num_new_pages = merged_value >> 32
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            prefix_lens=prefix_lens_cpu,
+        )
         if num_new_pages > len(self.free_pages):
             return None
 
@@ -517,6 +507,7 @@
     def alloc_decode(
         self,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
     ):
         if self.debug_mode:
@@ -534,7 +525,6 @@
             last_loc,
             self.free_pages,
             out_indices,
-            self.ret_values,
             next_power_of_2(bs),
             self.page_size,
         )
@@ -542,7 +532,11 @@
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)
 
-        num_new_pages = self.ret_values.item()
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            decode=True,
+        )
         if num_new_pages > len(self.free_pages):
             return None
 
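
Both allocators now derive the number of newly required pages on the host via get_num_new_pages (newly exported from sglang.srt.utils) instead of reading a packed result back from the Triton kernels. Its body is not included in this diff; the sketch below is a guess that matches the keyword arguments used above and the page arithmetic the kernels used to perform, for illustration only:

from typing import Optional

import torch


def get_num_new_pages(
    seq_lens: torch.Tensor,
    page_size: int,
    prefix_lens: Optional[torch.Tensor] = None,
    decode: bool = False,
) -> int:
    """Count pages that must be freshly allocated for a batch (CPU tensors expected)."""
    if decode:
        # Decode appends exactly one token per request; a new page is needed
        # only when that token starts a fresh page.
        return int((seq_lens % page_size == 1).sum().item())
    # Extend: pages covering [prefix_len, seq_len) that are not already allocated,
    # i.e. the ceil-division difference between seq_lens and prefix_lens.
    new_pages = (seq_lens + page_size - 1) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    return int(new_pages.sum().item())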

sglang/srt/mem_cache/allocator_ascend.py

@@ -1,13 +1,9 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
-
 import torch
 
 from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
-
-if TYPE_CHECKING:
-    from sglang.srt.mem_cache.memory_pool import KVCache
+from sglang.srt.utils import get_num_new_pages
 
 
 def alloc_extend_kernel_ascend(
@@ -69,7 +65,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
+        prefix_lens_cpu: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
     ):
@@ -79,42 +77,54 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         )
 
         num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if self.need_sort and num_new_pages > len(self.free_pages):
+            (seq_lens + self.page_size - 1) // self.page_size
+            - (prefix_lens + self.page_size - 1) // self.page_size
+        ).sum()
+        num_new_pages_item = num_new_pages.item()
+        if self.need_sort and num_new_pages_item > len(self.free_pages):
             self.merge_and_sort_free()
 
-        if num_new_pages > len(self.free_pages):
+        if num_new_pages_item > len(self.free_pages):
             return None
 
         out_indices = torch.empty(
-            (extend_num_tokens,), dtype=torch.int32, device=self.device
+            (extend_num_tokens,), dtype=torch.int64, device=self.device
         )
 
-        alloc_extend_kernel_ascend(
-            prefix_lens,
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-            self.device,
-        )
+        if num_new_pages_item < 200:
+            import sgl_kernel_npu
+
+            torch.ops.npu.alloc_extend(
+                prefix_lens,
+                seq_lens,
+                last_loc,
+                self.free_pages,
+                self.page_size,
+                out_indices,
+                num_new_pages,
+            )
+
+        else:
+            alloc_extend_kernel_ascend(
+                prefix_lens,
+                seq_lens,
+                last_loc,
+                self.free_pages,
+                out_indices,
+                self.page_size,
+                self.device,
+            )
 
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)
 
-        self.free_pages = self.free_pages[num_new_pages:]
+        self.free_pages = self.free_pages[num_new_pages_item:]
         return out_indices
 
     def alloc_decode(
         self,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
     ):
         if self.debug_mode:
@@ -122,8 +132,11 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
                 (last_loc + 2) % self.page_size == seq_lens % self.page_size
             )
 
-        need_new_pages = (seq_lens % self.page_size == 1).int()
-        num_new_pages = need_new_pages.sum().item()
+        num_new_pages = get_num_new_pages(
+            seq_lens=seq_lens_cpu,
+            page_size=self.page_size,
+            decode=True,
+        )
 
         if num_new_pages > len(self.free_pages):
             self.merge_and_sort_free()
@@ -131,6 +144,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         if num_new_pages > len(self.free_pages):
             return None
 
+        need_new_pages = (seq_lens % self.page_size == 1).int()
         end_new_pages = torch.cumsum(need_new_pages, 0)
         start_new_pages = end_new_pages - need_new_pages
         if num_new_pages == 0:

sglang/srt/mem_cache/base_prefix_cache.py

@@ -36,7 +36,7 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
+    def match_prefix(self, key: Any, **kwargs) -> MatchResult:
         pass
 
     @abstractmethod

sglang/srt/mem_cache/chunk_cache.py

@@ -28,6 +28,13 @@ class ChunkCache(BasePrefixCache):
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.page_size = page_size
 
+    # NOTE (csy): this is to determine if a cache has prefix matching feature.
+    # Chunk cache always return True to indicate no prefix matching.
+    # TODO (csy): Using a prefix cache trait to replace this
+    @property
+    def disable(self):
+        return True
+
     def reset(self):
         pass
 
@@ -38,7 +45,7 @@
             last_host_node=None,
         )
 
-    def cache_finished_req(self, req: Req):
+    def cache_finished_req(self, req: Req, insert: bool = True):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx,
             # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
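
The new disable property gives ChunkCache the same "no prefix matching" signal that the radix caches expose, so callers can branch on it uniformly. A hedged usage sketch (the helper name and the caller-side check below are hypothetical, not part of this diff):

def maybe_match_prefix(tree_cache, token_ids):
    # Hypothetical helper: skip prefix matching for caches that advertise
    # disable=True (e.g. ChunkCache), otherwise delegate to the cache.
    if getattr(tree_cache, "disable", False):
        return None  # no reusable prefix; allocate KV cache from scratch
    return tree_cache.match_prefix(key=token_ids)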

sglang/srt/mem_cache/evict_policy.py (new file)

@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, List, Tuple, Union
+
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.radix_cache import TreeNode
+
+
+class EvictionStrategy(ABC):
+    @abstractmethod
+    def get_priority(self, node: "TreeNode") -> Union[float, Tuple]:
+        pass
+
+
+class LRUStrategy(EvictionStrategy):
+    def get_priority(self, node: "TreeNode") -> float:
+        return node.last_access_time
+
+
+class LFUStrategy(EvictionStrategy):
+    def get_priority(self, node: "TreeNode") -> Tuple[int, float]:
+        return (node.hit_count, node.last_access_time)
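
sglang/srt/mem_cache/evict_policy.py is new in this release and pairs with the radix_cache.py changes (+222 -71). How the cache consumes these strategies is not part of the hunk above; the snippet below only illustrates ranking eviction candidates with them (the heap-based ordering and the function name are assumptions, not the radix cache's actual evict loop):

import heapq

from sglang.srt.mem_cache.evict_policy import LFUStrategy, LRUStrategy


def pick_eviction_order(leaves, policy: str = "lru"):
    """Order leaf nodes from most evictable to least evictable (illustrative only)."""
    strategy = LRUStrategy() if policy == "lru" else LFUStrategy()
    # Smaller priority == evict sooner: oldest access time for LRU,
    # (fewest hits, oldest access time) for LFU.
    heap = [(strategy.get_priority(node), i, node) for i, node in enumerate(leaves)]
    heapq.heapify(heap)
    ordered = []
    while heap:
        _, _, node = heapq.heappop(heap)
        ordered.append(node)
    return ordered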