sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff compares publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (395)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/lang/interpreter.py +1 -1
  7. sglang/launch_server.py +14 -0
  8. sglang/profiler.py +2 -2
  9. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  10. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  11. sglang/srt/configs/__init__.py +8 -0
  12. sglang/srt/configs/device_config.py +3 -1
  13. sglang/srt/configs/dots_ocr.py +64 -0
  14. sglang/srt/configs/dots_vlm.py +139 -0
  15. sglang/srt/configs/falcon_h1.py +360 -0
  16. sglang/srt/configs/internvl.py +6 -0
  17. sglang/srt/configs/load_config.py +9 -0
  18. sglang/srt/configs/model_config.py +181 -82
  19. sglang/srt/configs/qwen3_next.py +326 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +71 -19
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +326 -53
  44. sglang/srt/disaggregation/prefill.py +36 -17
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +192 -113
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  52. sglang/srt/entrypoints/grpc_server.py +810 -0
  53. sglang/srt/entrypoints/http_server.py +132 -57
  54. sglang/srt/entrypoints/openai/protocol.py +115 -7
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +207 -58
  57. sglang/srt/entrypoints/openai/serving_completions.py +17 -4
  58. sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +49 -4
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/environ.py +285 -0
  63. sglang/srt/eplb/eplb_manager.py +2 -2
  64. sglang/srt/eplb/expert_distribution.py +26 -13
  65. sglang/srt/eplb/expert_location.py +38 -8
  66. sglang/srt/eplb/expert_location_updater.py +1 -1
  67. sglang/srt/function_call/base_format_detector.py +3 -6
  68. sglang/srt/function_call/ebnf_composer.py +11 -9
  69. sglang/srt/function_call/function_call_parser.py +9 -2
  70. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  71. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  72. sglang/srt/function_call/json_array_parser.py +63 -0
  73. sglang/srt/function_call/kimik2_detector.py +17 -4
  74. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  75. sglang/srt/function_call/utils.py +96 -5
  76. sglang/srt/grpc/__init__.py +1 -0
  77. sglang/srt/grpc/compile_proto.py +245 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  81. sglang/srt/layers/activation.py +143 -9
  82. sglang/srt/layers/attention/aiter_backend.py +106 -82
  83. sglang/srt/layers/attention/ascend_backend.py +115 -9
  84. sglang/srt/layers/attention/attention_registry.py +206 -0
  85. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  86. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  87. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  88. sglang/srt/layers/attention/fla/chunk.py +242 -0
  89. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  90. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  91. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  92. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  93. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  94. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  95. sglang/srt/layers/attention/fla/index.py +37 -0
  96. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  97. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  98. sglang/srt/layers/attention/fla/op.py +66 -0
  99. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  100. sglang/srt/layers/attention/fla/utils.py +331 -0
  101. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  102. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  103. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  104. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  105. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  106. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  107. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  108. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  109. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  111. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  112. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  113. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  114. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  115. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  121. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  122. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  123. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  124. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  125. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  126. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  127. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  128. sglang/srt/layers/attention/nsa/utils.py +24 -0
  129. sglang/srt/layers/attention/nsa_backend.py +887 -0
  130. sglang/srt/layers/attention/tbo_backend.py +6 -6
  131. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  132. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  133. sglang/srt/layers/attention/triton_backend.py +57 -7
  134. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  135. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  136. sglang/srt/layers/attention/vision.py +58 -0
  137. sglang/srt/layers/attention/wave_backend.py +4 -4
  138. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  139. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  140. sglang/srt/layers/communicator.py +53 -7
  141. sglang/srt/layers/dp_attention.py +41 -2
  142. sglang/srt/layers/elementwise.py +3 -1
  143. sglang/srt/layers/layernorm.py +34 -15
  144. sglang/srt/layers/linear.py +55 -7
  145. sglang/srt/layers/logits_processor.py +44 -12
  146. sglang/srt/layers/moe/__init__.py +2 -1
  147. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  148. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  149. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  150. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  151. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  167. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  169. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  170. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  171. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  172. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  173. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  174. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  175. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  176. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  177. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  178. sglang/srt/layers/moe/topk.py +30 -9
  179. sglang/srt/layers/moe/utils.py +22 -7
  180. sglang/srt/layers/parameter.py +23 -6
  181. sglang/srt/layers/quantization/awq.py +19 -7
  182. sglang/srt/layers/quantization/base_config.py +11 -6
  183. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  184. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  185. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  186. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  187. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  188. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  189. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  190. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  191. sglang/srt/layers/quantization/fp8.py +78 -49
  192. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  193. sglang/srt/layers/quantization/gptq.py +25 -17
  194. sglang/srt/layers/quantization/modelopt_quant.py +225 -57
  195. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  196. sglang/srt/layers/quantization/mxfp4.py +77 -42
  197. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  199. sglang/srt/layers/quantization/quark/utils.py +97 -0
  200. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  201. sglang/srt/layers/quantization/unquant.py +135 -47
  202. sglang/srt/layers/quantization/w4afp8.py +26 -17
  203. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  204. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  205. sglang/srt/layers/rocm_linear_utils.py +44 -0
  206. sglang/srt/layers/rotary_embedding.py +78 -49
  207. sglang/srt/layers/sampler.py +213 -21
  208. sglang/srt/layers/utils.py +23 -0
  209. sglang/srt/lora/backend/base_backend.py +50 -8
  210. sglang/srt/lora/backend/chunked_backend.py +348 -0
  211. sglang/srt/lora/backend/triton_backend.py +99 -5
  212. sglang/srt/lora/layers.py +32 -0
  213. sglang/srt/lora/lora.py +8 -3
  214. sglang/srt/lora/lora_manager.py +52 -118
  215. sglang/srt/lora/mem_pool.py +25 -11
  216. sglang/srt/lora/triton_ops/__init__.py +4 -0
  217. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  219. sglang/srt/lora/utils.py +22 -11
  220. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  221. sglang/srt/managers/cache_controller.py +215 -314
  222. sglang/srt/managers/data_parallel_controller.py +115 -80
  223. sglang/srt/managers/detokenizer_manager.py +19 -15
  224. sglang/srt/managers/disagg_service.py +46 -0
  225. sglang/srt/managers/io_struct.py +340 -109
  226. sglang/srt/managers/mm_utils.py +44 -6
  227. sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
  228. sglang/srt/managers/multimodal_processor.py +1 -2
  229. sglang/srt/managers/overlap_utils.py +53 -0
  230. sglang/srt/managers/schedule_batch.py +240 -138
  231. sglang/srt/managers/schedule_policy.py +147 -19
  232. sglang/srt/managers/scheduler.py +501 -304
  233. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  234. sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
  235. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  236. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  237. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  238. sglang/srt/managers/template_manager.py +3 -3
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +321 -632
  241. sglang/srt/managers/tp_worker.py +81 -22
  242. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  243. sglang/srt/managers/utils.py +1 -45
  244. sglang/srt/mem_cache/allocator.py +15 -21
  245. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  246. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  247. sglang/srt/mem_cache/chunk_cache.py +8 -1
  248. sglang/srt/mem_cache/evict_policy.py +23 -0
  249. sglang/srt/mem_cache/hicache_storage.py +58 -34
  250. sglang/srt/mem_cache/hiradix_cache.py +227 -80
  251. sglang/srt/mem_cache/memory_pool.py +535 -58
  252. sglang/srt/mem_cache/memory_pool_host.py +239 -223
  253. sglang/srt/mem_cache/radix_cache.py +222 -73
  254. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  255. sglang/srt/mem_cache/storage/__init__.py +10 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  257. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  258. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  259. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  260. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  261. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  262. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  263. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
  264. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  265. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  266. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
  267. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  268. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  269. sglang/srt/metrics/collector.py +519 -132
  270. sglang/srt/metrics/func_timer.py +2 -7
  271. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  272. sglang/srt/metrics/utils.py +55 -0
  273. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  274. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  275. sglang/srt/model_executor/forward_batch_info.py +98 -57
  276. sglang/srt/model_executor/model_runner.py +433 -158
  277. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  278. sglang/srt/model_loader/__init__.py +9 -3
  279. sglang/srt/model_loader/loader.py +133 -5
  280. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  281. sglang/srt/model_loader/weight_utils.py +158 -3
  282. sglang/srt/models/apertus.py +686 -0
  283. sglang/srt/models/bailing_moe.py +820 -217
  284. sglang/srt/models/bailing_moe_nextn.py +168 -0
  285. sglang/srt/models/deepseek_nextn.py +6 -1
  286. sglang/srt/models/deepseek_v2.py +833 -152
  287. sglang/srt/models/dots_ocr.py +173 -0
  288. sglang/srt/models/dots_vlm.py +174 -0
  289. sglang/srt/models/dots_vlm_vit.py +337 -0
  290. sglang/srt/models/ernie4.py +1 -1
  291. sglang/srt/models/falcon_h1.py +576 -0
  292. sglang/srt/models/gemma3_causal.py +0 -2
  293. sglang/srt/models/gemma3_mm.py +1 -1
  294. sglang/srt/models/gemma3n_mm.py +2 -2
  295. sglang/srt/models/glm4_moe.py +14 -5
  296. sglang/srt/models/glm4_moe_nextn.py +2 -2
  297. sglang/srt/models/glm4v.py +5 -3
  298. sglang/srt/models/glm4v_moe.py +4 -1
  299. sglang/srt/models/gpt_oss.py +8 -31
  300. sglang/srt/models/internvl.py +28 -0
  301. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  302. sglang/srt/models/llama.py +4 -0
  303. sglang/srt/models/llama4.py +9 -0
  304. sglang/srt/models/llama_eagle3.py +13 -0
  305. sglang/srt/models/longcat_flash.py +3 -3
  306. sglang/srt/models/longcat_flash_nextn.py +1 -1
  307. sglang/srt/models/minicpmv.py +165 -3
  308. sglang/srt/models/mllama4.py +40 -4
  309. sglang/srt/models/opt.py +637 -0
  310. sglang/srt/models/qwen2_5_vl.py +29 -5
  311. sglang/srt/models/qwen2_audio.py +1 -1
  312. sglang/srt/models/qwen2_moe.py +124 -14
  313. sglang/srt/models/qwen2_vl.py +1 -1
  314. sglang/srt/models/qwen3.py +26 -5
  315. sglang/srt/models/qwen3_moe.py +71 -12
  316. sglang/srt/models/qwen3_next.py +1069 -0
  317. sglang/srt/models/qwen3_next_mtp.py +112 -0
  318. sglang/srt/models/qwen3_vl.py +787 -0
  319. sglang/srt/models/qwen3_vl_moe.py +471 -0
  320. sglang/srt/models/registry.py +15 -3
  321. sglang/srt/models/sarashina2_vision.py +269 -0
  322. sglang/srt/models/solar.py +505 -0
  323. sglang/srt/models/starcoder2.py +357 -0
  324. sglang/srt/models/step3_vl.py +1 -1
  325. sglang/srt/models/torch_native_llama.py +10 -3
  326. sglang/srt/models/utils.py +51 -0
  327. sglang/srt/multimodal/processors/base_processor.py +15 -7
  328. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  329. sglang/srt/multimodal/processors/glm4v.py +9 -9
  330. sglang/srt/multimodal/processors/internvl.py +153 -129
  331. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  332. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  333. sglang/srt/offloader.py +27 -3
  334. sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
  335. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  336. sglang/srt/sampling/sampling_batch_info.py +38 -17
  337. sglang/srt/sampling/sampling_params.py +7 -0
  338. sglang/srt/server_args.py +1030 -254
  339. sglang/srt/server_args_config_parser.py +146 -0
  340. sglang/srt/single_batch_overlap.py +151 -0
  341. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  342. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  343. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  344. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  345. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  346. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  347. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  348. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  349. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  350. sglang/srt/speculative/eagle_worker.py +253 -136
  351. sglang/srt/speculative/ngram_utils.py +428 -0
  352. sglang/srt/speculative/ngram_worker.py +245 -0
  353. sglang/srt/speculative/spec_info.py +52 -0
  354. sglang/srt/speculative/spec_utils.py +606 -0
  355. sglang/srt/speculative/standalone_worker.py +109 -0
  356. sglang/srt/torch_memory_saver_adapter.py +5 -7
  357. sglang/srt/tracing/trace.py +578 -0
  358. sglang/srt/two_batch_overlap.py +8 -5
  359. sglang/srt/utils/__init__.py +2 -0
  360. sglang/srt/{utils.py → utils/common.py} +445 -77
  361. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  362. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  363. sglang/srt/utils/rpd_utils.py +452 -0
  364. sglang/srt/utils/slow_rank_detector.py +71 -0
  365. sglang/srt/warmup.py +8 -4
  366. sglang/srt/weight_sync/utils.py +2 -2
  367. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  368. sglang/test/few_shot_gsm8k.py +1 -0
  369. sglang/test/get_logits_ut.py +57 -0
  370. sglang/test/run_eval.py +79 -11
  371. sglang/test/runners.py +5 -1
  372. sglang/test/simple_eval_common.py +5 -2
  373. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  374. sglang/test/test_block_fp8.py +2 -2
  375. sglang/test/test_cutlass_moe.py +24 -6
  376. sglang/test/test_deterministic.py +297 -0
  377. sglang/test/test_disaggregation_utils.py +77 -0
  378. sglang/test/test_fp4_moe.py +370 -1
  379. sglang/test/test_programs.py +1 -1
  380. sglang/test/test_utils.py +383 -5
  381. sglang/utils.py +22 -1
  382. sglang/version.py +1 -1
  383. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  384. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
  385. sglang/srt/disaggregation/launch_lb.py +0 -118
  386. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  387. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  388. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  389. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  390. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  391. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  392. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  393. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  394. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  395. {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/launch_lb.py +0 -118
@@ -1,118 +0,0 @@
- import argparse
- import dataclasses
-
- from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
-
-
- @dataclasses.dataclass
- class LBArgs:
-     host: str = "0.0.0.0"
-     port: int = 8000
-     policy: str = "random"
-     prefill_infos: list = dataclasses.field(default_factory=list)
-     decode_infos: list = dataclasses.field(default_factory=list)
-     log_interval: int = 5
-     timeout: int = 600
-
-     @staticmethod
-     def add_cli_args(parser: argparse.ArgumentParser):
-         parser.add_argument(
-             "--host",
-             type=str,
-             default=LBArgs.host,
-             help=f"Host to bind the server (default: {LBArgs.host})",
-         )
-         parser.add_argument(
-             "--port",
-             type=int,
-             default=LBArgs.port,
-             help=f"Port to bind the server (default: {LBArgs.port})",
-         )
-         parser.add_argument(
-             "--policy",
-             type=str,
-             default=LBArgs.policy,
-             choices=["random", "po2"],
-             help=f"Policy to use for load balancing (default: {LBArgs.policy})",
-         )
-         parser.add_argument(
-             "--prefill",
-             type=str,
-             default=[],
-             nargs="+",
-             help="URLs for prefill servers",
-         )
-         parser.add_argument(
-             "--decode",
-             type=str,
-             default=[],
-             nargs="+",
-             help="URLs for decode servers",
-         )
-         parser.add_argument(
-             "--prefill-bootstrap-ports",
-             type=int,
-             nargs="+",
-             help="Bootstrap ports for prefill servers",
-         )
-         parser.add_argument(
-             "--log-interval",
-             type=int,
-             default=LBArgs.log_interval,
-             help=f"Log interval in seconds (default: {LBArgs.log_interval})",
-         )
-         parser.add_argument(
-             "--timeout",
-             type=int,
-             default=LBArgs.timeout,
-             help=f"Timeout in seconds (default: {LBArgs.timeout})",
-         )
-
-     @classmethod
-     def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
-         bootstrap_ports = args.prefill_bootstrap_ports
-         if bootstrap_ports is None:
-             bootstrap_ports = [None] * len(args.prefill)
-         elif len(bootstrap_ports) == 1:
-             bootstrap_ports = bootstrap_ports * len(args.prefill)
-         else:
-             if len(bootstrap_ports) != len(args.prefill):
-                 raise ValueError(
-                     "Number of prefill URLs must match number of bootstrap ports"
-                 )
-
-         prefill_infos = [
-             (url, port) for url, port in zip(args.prefill, bootstrap_ports)
-         ]
-
-         return cls(
-             host=args.host,
-             port=args.port,
-             policy=args.policy,
-             prefill_infos=prefill_infos,
-             decode_infos=args.decode,
-             log_interval=args.log_interval,
-             timeout=args.timeout,
-         )
-
-
- def main():
-     parser = argparse.ArgumentParser(
-         description="PD Disaggregation Load Balancer Server"
-     )
-     LBArgs.add_cli_args(parser)
-     args = parser.parse_args()
-     lb_args = LBArgs.from_cli_args(args)
-
-     prefill_configs = [PrefillConfig(url, port) for url, port in lb_args.prefill_infos]
-     run(
-         prefill_configs,
-         lb_args.decode_infos,
-         lb_args.host,
-         lb_args.port,
-         lb_args.timeout,
-     )
-
-
- if __name__ == "__main__":
-     main()
sglang/srt/mem_cache/lora_radix_cache.py +0 -421
@@ -1,421 +0,0 @@
- """Radix cache for LoRA. It's modified based on RadixCache with lora_id added to the key of nodes."""
-
- import heapq
- import time
- from collections import defaultdict
- from typing import TYPE_CHECKING, Any, List, Optional
-
- import torch
-
- from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
- from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
- from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
-
- if TYPE_CHECKING:
-     from sglang.srt.managers.schedule_batch import Req
- else:
-     Req = Any  # Placeholder for Req type when not type checking
-
-
- class LoRAKey:
-
-     def __init__(self, lora_id: str, token_ids: List[int]):
-         self.lora_id = (
-             lora_id  # lora_id of adaptor, should be hash value of adaptor path
-         )
-         self.token_ids = token_ids  # token_ids of the key
-
-     def __len__(self):
-         return len(self.token_ids)
-
-
- def get_child_key(key: LoRAKey):
-     # Here the key of children dict is the hash of lora_id + str(token_ids[0])
-     # So the child key can be matched only when lora_id and token_ids[0] are the same
-     if key.lora_id is None:
-         return hash(str(key.token_ids[0]))
-     else:
-         return hash(key.lora_id + str(key.token_ids[0]))
-
-
- class LoRATreeNode:
-
-     counter = 0
-
-     def __init__(self, id: Optional[int] = None):
-         self.children = defaultdict(LoRATreeNode)
-         self.parent: LoRATreeNode = None
-         self.key: LoRAKey = None
-         self.value: Optional[torch.Tensor] = None
-         self.lock_ref = 0
-         self.last_access_time = time.monotonic()
-
-         self.id = LoRATreeNode.counter if id is None else id
-         LoRATreeNode.counter += 1
-
-     @property
-     def evicted(self):
-         return self.value is None
-
-     def __lt__(self, other: "LoRATreeNode"):
-         return self.last_access_time < other.last_access_time
-
-
- def _key_match(key0: LoRAKey, key1: LoRAKey):
-     if key0.lora_id != key1.lora_id:
-         raise ValueError(
-             f"_key_match should be run on the same lora_id, but got key0.lora_id={key0.lora_id} != key1.lora_id={key1.lora_id}"
-         )
-     i = 0
-     for k0, k1 in zip(key0.token_ids, key1.token_ids):
-         if k0 != k1:
-             break
-         i += 1
-     return i
-
-
- class LoRARadixCache(BasePrefixCache):
-
-     def __init__(
-         self,
-         req_to_token_pool: ReqToTokenPool,
-         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
-         page_size: int,
-         disable: bool = False,
-     ):
-         if page_size > 1:
-             raise ValueError("LoRARadixCache currently only supports page_size = 1")
-
-         if token_to_kv_pool_allocator is None:
-             raise ValueError(
-                 "token_to_kv_pool_allocator is required to run LoraRadixCache"
-             )
-
-         self.req_to_token_pool = req_to_token_pool
-         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
-         self.page_size = page_size
-         self.disable = disable
-         self.device = self.token_to_kv_pool_allocator.device
-
-         self.key_match_fn = _key_match
-         self.get_child_key_fn = get_child_key
-         self.reset()
-
-     def reset(self):
-         self.root_node = LoRATreeNode()
-         self.root_node.key = LoRAKey(lora_id="", token_ids=[])
-         self.root_node.value = None
-         self.evictable_size_ = 0
-         self.protected_size_ = 0
-
-     def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
-         raise ValueError(
-             "LoRARadixCache needs both token ids and lora id as inputs for matching. Please use match_prefix_with_lora_id instead."
-         )
-
-     def match_prefix_with_lora_id(self, key: LoRAKey, **kwargs) -> MatchResult:
-         """Find the matching prefix from the lora radix tree.
-         Args:
-             key: A LoRAKey to find a matching prefix.
-         Returns:
-             A tuple of a tensor of matching prefix token IDs and
-             the last node that contains the prefix values. Note that
-             this API can modify the internal state of the Radix tree.
-             The last node create a new child if the prefix is shorter
-             than the last node's value.
-         """
-         if self.disable or len(key) == 0:
-             return MatchResult(
-                 device_indices=torch.empty(
-                     (0,),
-                     dtype=torch.int64,
-                     device=self.device,
-                 ),
-                 last_device_node=self.root_node,
-                 last_host_node=self.root_node,
-             )
-
-         value, last_node = self._match_prefix_helper(self.root_node, key)
-         if value:
-             value = torch.cat(value)
-         else:
-             value = torch.empty((0,), dtype=torch.int64, device=self.device)
-         return MatchResult(
-             device_indices=value,
-             last_device_node=last_node,
-             last_host_node=last_node,
-         )
-
-     def insert(self, key: LoRAKey, value=None):
-         if self.disable:
-             return 0
-
-         if value is None:
-             value = [x for x in key.token_ids]
-         return self._insert_helper(self.root_node, key, value)
-
-     def cache_finished_req(self, req: Req):
-         """Cache request when it finishes."""
-         if self.disable:
-             kv_indices = self.req_to_token_pool.req_to_token[
-                 req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
-             ]
-             self.token_to_kv_pool_allocator.free(kv_indices)
-             self.req_to_token_pool.free(req.req_pool_idx)
-             return
-
-         token_ids = (req.origin_input_ids + req.output_ids)[:-1]
-         kv_indices = self.req_to_token_pool.req_to_token[
-             req.req_pool_idx, : len(token_ids)
-         ]
-
-         page_aligned_len = len(kv_indices)
-         page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
-
-         # Radix Cache takes one ref in memory pool
-         lora_key = LoRAKey(lora_id=req.lora_id, token_ids=token_ids[:page_aligned_len])
-         new_prefix_len = self.insert(lora_key, page_aligned_kv_indices)
-         self.token_to_kv_pool_allocator.free(
-             kv_indices[len(req.prefix_indices) : new_prefix_len]
-         )
-
-         # Remove req slot release the cache lock
-         self.req_to_token_pool.free(req.req_pool_idx)
-         self.dec_lock_ref(req.last_node)
-
-     def cache_unfinished_req(self, req: Req, chunked=False):
-         """Cache request when it is unfinished."""
-         if self.disable:
-             return
-
-         token_ids = req.fill_ids
-         kv_indices = self.req_to_token_pool.req_to_token[
-             req.req_pool_idx, : len(token_ids)
-         ]
-
-         page_aligned_len = len(kv_indices)
-         page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
-         page_aligned_token_ids = token_ids[:page_aligned_len]
-
-         # Radix Cache takes one ref in memory pool
-         inserted_key = LoRAKey(lora_id=req.lora_id, token_ids=page_aligned_token_ids)
-         new_prefix_len = self.insert(inserted_key, page_aligned_kv_indices)
-         self.token_to_kv_pool_allocator.free(
-             kv_indices[len(req.prefix_indices) : new_prefix_len]
-         )
-
-         # The prefix indices could be updated, reuse it
-         new_indices, new_last_node, _, _ = self.match_prefix_with_lora_id(inserted_key)
-         self.req_to_token_pool.write(
-             (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
-             new_indices[len(req.prefix_indices) :],
-         )
-
-         self.dec_lock_ref(req.last_node)
-         self.inc_lock_ref(new_last_node)
-
-         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
-         req.prefix_indices = new_indices
-         req.last_node = new_last_node
-
-     def pretty_print(self):
-         self._print_helper(self.root_node, 0)
-         print(f"#tokens: {self.total_size()}")
-
-     def total_size(self):
-         return self._total_size_helper()
-
-     def evict(self, num_tokens: int):
-         if self.disable:
-             return
-
-         leaves = self._collect_leaves()
-         heapq.heapify(leaves)
-
-         num_evicted = 0
-         while num_evicted < num_tokens and len(leaves):
-             x = heapq.heappop(leaves)
-
-             if x == self.root_node:
-                 break
-             if x.lock_ref > 0:
-                 continue
-
-             self.token_to_kv_pool_allocator.free(x.value)
-             num_evicted += len(x.value)
-             self._delete_leaf(x)
-
-             if len(x.parent.children) == 0:
-                 heapq.heappush(leaves, x.parent)
-
-     def inc_lock_ref(self, node: LoRATreeNode):
-         if self.disable:
-             return 0
-
-         delta = 0
-         while node != self.root_node:
-             if node.lock_ref == 0:
-                 self.evictable_size_ -= len(node.value)
-                 self.protected_size_ += len(node.value)
-                 delta -= len(node.value)
-             node.lock_ref += 1
-             node = node.parent
-         return delta
-
-     def dec_lock_ref(self, node: LoRATreeNode):
-         if self.disable:
-             return 0
-
-         delta = 0
-         while node != self.root_node:
-             if node.lock_ref == 1:
-                 self.evictable_size_ += len(node.value)
-                 self.protected_size_ -= len(node.value)
-                 delta += len(node.value)
-             node.lock_ref -= 1
-             node = node.parent
-         return delta
-
-     def evictable_size(self):
-         return self.evictable_size_
-
-     def protected_size(self):
-         # protected size refers to the size of the cache that is locked
-         return self.protected_size_
-
-     def all_values_flatten(self):
-         values = []
-
-         def _dfs_helper(node: LoRATreeNode):
-             for _, child in node.children.items():
-                 values.append(child.value)
-                 _dfs_helper(child)
-
-         _dfs_helper(self.root_node)
-         return torch.cat(values)
-
-     ##### Internal Helper Functions #####
-
-     def _match_prefix_helper(self, node: LoRATreeNode, key: LoRAKey):
-         node.last_access_time = time.monotonic()
-
-         child_key = self.get_child_key_fn(key)
-
-         value = []
-         while len(key) > 0 and child_key in node.children.keys():
-             child = node.children[child_key]
-             child.last_access_time = time.monotonic()
-             prefix_len = self.key_match_fn(child.key, key)
-             if prefix_len < len(child.key):
-                 new_node = self._split_node(child.key, child, prefix_len)
-                 value.append(new_node.value)
-                 node = new_node
-                 break
-             else:
-                 value.append(child.value)
-                 node = child
-                 key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])
-
-                 if len(key):
-                     child_key = self.get_child_key_fn(key)
-
-         return value, node
-
-     def _split_node(self, key: LoRAKey, child: LoRATreeNode, split_len: int):
-         # new_node -> child
-         new_node = LoRATreeNode()
-         key_split_1 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[:split_len])
-         key_split_2 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[split_len:])
-         new_node.children = {self.get_child_key_fn(key_split_2): child}
-         new_node.parent = child.parent
-         new_node.lock_ref = child.lock_ref
-         new_node.key = key_split_1
-         new_node.value = child.value[:split_len]
-         child.parent = new_node
-         child.key = key_split_2
-         child.value = child.value[split_len:]
-         new_node.parent.children[self.get_child_key_fn(key)] = new_node
-
-         return new_node
-
-     def _insert_helper(self, node: LoRATreeNode, key: LoRAKey, value):
-         node.last_access_time = time.monotonic()
-         if len(key) == 0:
-             return 0
-
-         child_key = self.get_child_key_fn(key)
-
-         total_prefix_length = 0
-         while len(key) > 0 and child_key in node.children.keys():
-             node = node.children[child_key]
-             node.last_access_time = time.monotonic()
-             prefix_len = self.key_match_fn(node.key, key)
-             total_prefix_length += prefix_len
-             key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])
-             value = value[prefix_len:]
-
-             if prefix_len < len(node.key):
-                 new_node = self._split_node(node.key, node, prefix_len)
-                 node = new_node
-
-             if len(key):
-                 child_key = self.get_child_key_fn(key)
-
-         if len(key):
-             new_node = LoRATreeNode()
-             new_node.parent = node
-             new_node.key = key
-             new_node.value = value
-             node.children[child_key] = new_node
-             self.evictable_size_ += len(value)
-         return total_prefix_length
-
-     def _print_helper(self, node: LoRATreeNode, indent: int):
-         """Prints the radix tree in a human-readable format."""
-         stack = [(node, indent)]
-         while stack:
-             current_node, current_indent = stack.pop()
-             print(
-                 " " * current_indent,
-                 len(current_node.key),
-                 current_node.key.token_ids[:10],
-                 f"r={current_node.lock_ref}",
-             )
-             for key, child in current_node.children.items():
-                 stack.append((child, current_indent + 2))
-
-                 assert key == self.get_child_key_fn(
-                     child.key
-                 ), f"{key=}, {self.get_child_key_fn(child.key)=}"
-
-     def _delete_leaf(self, node):
-         for k, v in node.parent.children.items():
-             if v == node:
-                 break
-         del node.parent.children[k]
-         self.evictable_size_ -= len(node.key)
-
-     def _total_size_helper(self):
-         total_size = 0
-         stack = [self.root_node]
-         while stack:
-             current_node = stack.pop()
-             total_size += len(current_node.value)
-             for child in current_node.children.values():
-                 if child.evicted:
-                     continue
-                 stack.append(child)
-         return total_size
-
-     def _collect_leaves(self):
-         ret_list = []
-         stack = [self.root_node]
-
-         while stack:
-             cur_node = stack.pop()
-             if len(cur_node.children) == 0:
-                 ret_list.append(cur_node)
-             else:
-                 stack.extend(cur_node.children.values())
-
-         return ret_list
sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
@@ -1,40 +0,0 @@
- import torch
- from mooncake_store import MooncakeStore
-
-
- def test_init_and_warmup():
-     store = MooncakeStore()
-     assert store.store is not None
-
-
- def test_register_buffer():
-     store = MooncakeStore()
-     tensor = torch.zeros(1024, dtype=torch.float32)
-     store.register_buffer(tensor)
-
-
- def test_set_and_get():
-     store = MooncakeStore()
-
-     key = ["test_key_" + str(i) for i in range(2)]
-     tensor = torch.arange(256, dtype=torch.float32).cuda()
-     ptrs = [tensor.data_ptr(), tensor.data_ptr()]
-     sizes = [tensor.numel() * tensor.element_size()] * 2
-
-     store.set(key, target_location=ptrs, target_sizes=sizes)
-     store.get(key, target_location=ptrs, target_sizes=sizes)
-
-
- def test_exists():
-     store = MooncakeStore()
-     keys = ["test_key_0", "non_existent_key"]
-     result = store.exists(keys)
-     assert isinstance(result, dict)
-     assert "test_key_0" in result
-
-
- if __name__ == "__main__":
-     test_init_and_warmup()
-     test_register_buffer()
-     test_set_and_get()
-     test_exists()