sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,151 @@
1
+ import logging
2
+ from typing import Any, List, Optional
3
+
4
+ import torch
5
+ from aibrix_kvcache import (
6
+ BaseKVCacheManager,
7
+ BlockHashes,
8
+ KVCacheBlockLayout,
9
+ KVCacheBlockSpec,
10
+ KVCacheConfig,
11
+ KVCacheTensorSpec,
12
+ ModelSpec,
13
+ )
14
+ from aibrix_kvcache.common.absl_logging import log_every_n_seconds
15
+
16
+ from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
17
+ from sglang.srt.mem_cache.memory_pool_host import HostKVCache
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class AibrixKVCacheStorage(HiCacheStorage):
    """HiCacheStorage backend backed by the AIBrix `BaseKVCacheManager`.

    Host KV-cache pages are stored and fetched by block hash.  Only MHA-style
    caches are supported; MLA models are rejected at construction time.
    """

    def __init__(self, storage_config: HiCacheStorageConfig, mem_pool: HostKVCache):
        """Build the AIBrix cache manager from the host pool's geometry.

        Args:
            storage_config: Carries tp_rank and the MLA flag; when None the
                backend assumes rank 0 and a non-MLA model.
            mem_pool: Host KV cache whose device pool supplies page size,
                dtype, head/layer counts.

        Raises:
            NotImplementedError: if the model uses MLA attention.
        """
        if storage_config is not None:
            self.is_mla_backend = storage_config.is_mla_model
            self.local_rank = storage_config.tp_rank
        else:
            self.is_mla_backend = False
            self.local_rank = 0
        kv_cache = mem_pool.device_pool
        self.page_size = mem_pool.page_size
        self.kv_cache_dtype = kv_cache.dtype
        self.layer_num = kv_cache.layer_num
        # Global KV-head ids owned by this TP rank.
        self.kv_head_ids = [
            self.local_rank * kv_cache.head_num + i for i in range(kv_cache.head_num)
        ]
        if not self.is_mla_backend:
            self.layer_ids = range(
                kv_cache.start_layer, kv_cache.end_layer
            )  # for pipeline parallel
            self.block_spec = KVCacheBlockSpec(
                block_ntokens=self.page_size,
                block_dtype=self.kv_cache_dtype,
                block_layout=KVCacheBlockLayout(KVCacheBlockLayout.NCLD),
                tensor_spec=KVCacheTensorSpec(
                    heads=self.kv_head_ids,
                    layers=self.layer_ids,
                    head_size=kv_cache.head_dim,
                ),
            )
            logger.info(self.block_spec)
            config = KVCacheConfig(
                block_spec=self.block_spec, model_spec=ModelSpec(102400)
            )
            self.kv_cache_manager = BaseKVCacheManager(config)
        else:
            raise NotImplementedError(
                "MLA is not supported by AibrixKVCacheStorage yet."
            )

    def _aibrix_kvcache_metrics_report(self):
        """Return the current metrics summary string, then reset the counters.

        BUGFIX: this previously returned None (the summary was computed and
        discarded), so the periodic log in batch_get emitted an empty message.
        """
        summary = self.kv_cache_manager.metrics.summary()
        self.kv_cache_manager.metrics.reset()
        return summary

    def batch_get(
        self,
        keys: List[str],
        target_locations: List[torch.Tensor],
        target_sizes: Optional[Any] = None,
    ) -> List[torch.Tensor | None]:
        """Fetch cached pages for `keys` into `target_locations` (flattened).

        Returns the populated `target_locations` on a successful acquire,
        otherwise a list of None (the acquire here is all-or-nothing).
        """
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.acquire(None, block_hash)
        # NOTE(review): the report (and its metrics reset) runs on every call,
        # even when log_every_n_seconds suppresses the actual log line.
        log_every_n_seconds(
            logger, logging.INFO, self._aibrix_kvcache_metrics_report(), 1
        )
        if status.is_ok():
            num_fetched_tokens, handle = status.value
            kv_blocks = handle.to_tensors()
            assert len(kv_blocks) == len(target_locations)
            for i in range(len(kv_blocks)):
                assert (
                    target_locations[i].nbytes == kv_blocks[i].nbytes
                ), f"{target_locations[i].nbytes}, {kv_blocks[i].nbytes}"
                target_locations[i].copy_(kv_blocks[i].flatten())
            # Release the handle only after all copies are done.
            handle.release()
            return target_locations

        return [None] * len(keys)

    def get(
        self,
        key: str,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> torch.Tensor | None:
        """Single-key convenience wrapper around batch_get."""
        return self.batch_get([key], [target_location], [target_size])[0]

    def batch_set(
        self,
        keys: List[str],
        values: Optional[Any] = None,
        target_locations: Optional[Any] = None,
        target_sizes: Optional[Any] = None,
    ) -> bool:
        """Store one page tensor per key; True only if fully persisted."""
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.allocate_for(None, block_hash)
        if not status.is_ok():
            logger.warning(
                f"aibrix_kvcache set allocate failed, error_code {status.error_code}"
            )
            return False
        handle = status.value
        tensors = handle.to_tensors()
        if len(tensors) != len(values):
            logger.warning("aibrix_kvcache set allocate not enough")
            return False
        for i in range(len(tensors)):
            assert (
                tensors[i].nbytes == values[i].nbytes
            ), f"{tensors[i].nbytes}, {values[i].nbytes}"
            # View the destination with the source's shape and copy in place.
            # BUGFIX: the original chained a trailing `.reshape(tensors[i].shape)`
            # onto the result of copy_ and discarded it — a pure no-op, removed.
            tensors[i].reshape(values[i].shape).copy_(values[i])
        status = self.kv_cache_manager.put(None, block_hash, handle)
        if not status.is_ok():
            logger.info(
                f"AIBrix KVCache Storage set failed, error_code {status.error_code}"
            )
            return False
        completed = status.value
        # put reports completed token count; success means every page landed.
        return completed == len(keys) * self.page_size

    def set(
        self,
        key: str,
        value: Optional[Any] = None,
        target_location: Optional[Any] = None,
        target_size: Optional[Any] = None,
    ) -> bool:
        """Single-key convenience wrapper around batch_set."""
        return self.batch_set([key], [value], [target_location], [target_size])

    def batch_exists(self, keys: List[str]) -> int:
        """Return how many leading keys exist (token count // page size)."""
        block_hash = BlockHashes(keys, self.page_size)
        status = self.kv_cache_manager.exists(None, block_hash)
        if status.is_ok():
            return status.value // self.page_size
        return 0

    def exists(self, key: str) -> bool | dict:
        """True if the single key is present in the cache."""
        return self.batch_exists([key]) > 0
@@ -0,0 +1,109 @@
1
+ import logging
2
+ import os
3
+
4
+ import torch
5
+ import torch.distributed
6
+ from aibrix_kvcache import (
7
+ BaseKVCacheManager,
8
+ GroupAwareKVCacheManager,
9
+ KVCacheBlockLayout,
10
+ KVCacheBlockSpec,
11
+ KVCacheConfig,
12
+ KVCacheMetrics,
13
+ KVCacheTensorSpec,
14
+ ModelSpec,
15
+ TokenListView,
16
+ )
17
+ from aibrix_kvcache.common.absl_logging import getLogger, log_every_n_seconds, log_if
18
+ from aibrix_kvcache_storage import AibrixKVCacheStorage
19
+ from torch.distributed import Backend, ProcessGroup
20
+
21
+ from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
22
+ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
23
+ from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
24
+
25
+ logging.basicConfig(
26
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
27
+ )
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
def setup():
    """Populate env vars for a single-process torch.distributed rendezvous."""
    single_rank_env = {
        "RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "63886",
    }
    os.environ.update(single_rank_env)
37
+
38
+
39
class AIBrixKVCacheStorageTest:
    """Smoke test for AibrixKVCacheStorage over an in-memory CPU host pool."""

    def test_with_page_size(self):
        """Round-trip set/get/exists for page sizes 1 and 2.

        For each page size: store `query_length` random pages, read them all
        back, then read back only the second half (`partial_keys`) and verify
        contents and existence counts.
        """
        config = HiCacheStorageConfig(
            tp_rank=0,
            tp_size=1,
            is_mla_model=False,
            is_page_first_layout=True,
            model_name="test",
        )
        for page_size in range(1, 3):
            logger.info(f"page_size: {page_size}")
            batch_size = 2
            head_num = 1
            layer_num = 64
            head_dim = 128
            kv_cache = MHATokenToKVPool(
                1024,
                page_size,
                torch.float16,
                head_num,
                head_dim,
                layer_num,
                "cpu",
                False,
                0,
                layer_num,
            )
            mem_pool = MHATokenToKVPoolHost(kv_cache, 2, 0, page_size, "layer_first")
            query_length = batch_size * 2
            partial = batch_size
            self.aibrix_kvcache = AibrixKVCacheStorage(config, mem_pool)
            # One page per key: (K/V, layer, token, head, head_dim).
            target_shape = (2, layer_num, page_size, head_num, head_dim)
            rand_tensor = [
                torch.rand(target_shape, dtype=torch.float16)
                for _ in range(query_length)
            ]
            keys = ["hash" + str(i) for i in range(query_length)]
            partial_keys = keys[batch_size:query_length]
            assert self.aibrix_kvcache.batch_exists(keys) == 0
            assert self.aibrix_kvcache.batch_set(keys, rand_tensor)
            # Pre-fill destinations with random data so a silent no-op get
            # cannot pass the equality checks below.
            get_tensor = [
                torch.rand(target_shape, dtype=torch.float16).flatten()
                for _ in range(query_length)
            ]
            self.aibrix_kvcache.batch_get(keys, get_tensor)
            for i in range(query_length):
                assert torch.equal(get_tensor[i], rand_tensor[i].flatten())
            # BUGFIX: the original computed `ret` and discarded it, then called
            # batch_exists a second time; reuse the single result instead.
            ret = self.aibrix_kvcache.batch_exists(keys)
            assert ret == query_length
            assert self.aibrix_kvcache.batch_exists(partial_keys) == partial
            partial_get_tensor = [
                torch.rand(target_shape, dtype=torch.float16).flatten()
                for _ in range(partial)
            ]
            self.aibrix_kvcache.batch_get(partial_keys, partial_get_tensor)
            for i in range(partial):
                # partial_keys starts at index batch_size == partial.
                assert torch.equal(
                    partial_get_tensor[i], rand_tensor[i + partial].flatten()
                )
            log_every_n_seconds(
                logger,
                logging.INFO,
                self.aibrix_kvcache.kv_cache_manager.metrics.summary(),
                1,
            )
104
+
105
+
106
if __name__ == "__main__":
    # Configure the single-rank distributed env, then run the smoke test.
    setup()
    AIBrixKVCacheStorageTest().test_with_page_size()
@@ -0,0 +1,223 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to SGLang project
3
+
4
+ import importlib
5
+ import logging
6
+ from typing import TYPE_CHECKING, Any, Dict
7
+
8
+ from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
9
+
10
+ if TYPE_CHECKING:
11
+ pass
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class StorageBackendFactory:
    """Factory for creating storage backend instances with support for dynamic loading."""

    # name -> {"loader": zero-arg callable returning the class,
    #          "module_path": str, "class_name": str}
    _registry: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _load_backend_class(
        module_path: str, class_name: str, backend_name: str
    ) -> type[HiCacheStorage]:
        """Import `class_name` from `module_path` and validate its base class.

        Raises:
            ImportError: the module could not be imported.
            AttributeError: the class is missing from the module.
            TypeError: the class does not inherit from HiCacheStorage.
        """
        try:
            module = importlib.import_module(module_path)
            backend_class = getattr(module, class_name)
            if not issubclass(backend_class, HiCacheStorage):
                raise TypeError(
                    f"Backend class {class_name} must inherit from HiCacheStorage"
                )
            return backend_class
        except ImportError as e:
            raise ImportError(
                f"Failed to import backend '{backend_name}' from '{module_path}': {e}"
            ) from e
        except AttributeError as e:
            raise AttributeError(
                f"Class '{class_name}' not found in module '{module_path}': {e}"
            ) from e

    @classmethod
    def register_backend(cls, name: str, module_path: str, class_name: str) -> None:
        """Register a storage backend with lazy loading.

        Args:
            name: Backend identifier
            module_path: Python module path containing the backend class
            class_name: Name of the backend class
        """
        if name in cls._registry:
            logger.warning(f"Backend '{name}' is already registered, overwriting")

        def loader() -> type[HiCacheStorage]:
            """Lazy loader function to import the backend class."""
            return cls._load_backend_class(module_path, class_name, name)

        cls._registry[name] = {
            "loader": loader,
            "module_path": module_path,
            "class_name": class_name,
        }

    @classmethod
    def create_backend(
        cls,
        backend_name: str,
        storage_config: HiCacheStorageConfig,
        mem_pool_host: Any,
        **kwargs,
    ) -> HiCacheStorage:
        """Create a storage backend instance.

        Args:
            backend_name: Name of the backend to create
            storage_config: Storage configuration
            mem_pool_host: Memory pool host object
            **kwargs: Additional arguments passed to external backends
        Returns:
            Initialized storage backend instance
        Raises:
            ValueError: If backend is not registered and cannot be dynamically loaded
            ImportError: If backend module cannot be imported
            Exception: If backend initialization fails
        """
        # Registered backends take priority over dynamic loading.
        if backend_name in cls._registry:
            registry_entry = cls._registry[backend_name]
            backend_class = registry_entry["loader"]()
            logger.info(
                f"Creating storage backend '{backend_name}' "
                f"({registry_entry['module_path']}.{registry_entry['class_name']})"
            )
            return cls._create_builtin_backend(
                backend_name, backend_class, storage_config, mem_pool_host
            )

        # "dynamic" loads a user-supplied class described in extra_config.
        if backend_name == "dynamic" and storage_config.extra_config is not None:
            backend_config = storage_config.extra_config
            return cls._create_dynamic_backend(
                backend_config, storage_config, mem_pool_host, **kwargs
            )

        # Backend not found
        available_backends = list(cls._registry.keys())

        raise ValueError(
            f"Unknown storage backend '{backend_name}'. "
            f"Registered backends: {available_backends}. "
        )

    @classmethod
    def _create_dynamic_backend(
        cls,
        backend_config: Dict[str, Any],
        storage_config: HiCacheStorageConfig,
        mem_pool_host: Any,
        **kwargs,
    ) -> HiCacheStorage:
        """Create a backend dynamically from configuration."""
        required_fields = ["backend_name", "module_path", "class_name"]
        for field in required_fields:
            if field not in backend_config:
                raise ValueError(
                    f"Missing required field '{field}' in backend config for 'dynamic' backend"
                )

        backend_name = backend_config["backend_name"]
        module_path = backend_config["module_path"]
        class_name = backend_config["class_name"]

        try:
            # Import the backend class
            backend_class = cls._load_backend_class(
                module_path, class_name, backend_name
            )

            logger.info(
                f"Creating dynamic storage backend '{backend_name}' "
                f"({module_path}.{class_name})"
            )

            # Dynamic backends receive the config plus the extra kwargs dict.
            return backend_class(storage_config, kwargs)
        except Exception as e:
            logger.error(
                f"Failed to create dynamic storage backend '{backend_name}': {e}"
            )
            raise

    @classmethod
    def _create_builtin_backend(
        cls,
        backend_name: str,
        backend_class: type[HiCacheStorage],
        storage_config: HiCacheStorageConfig,
        mem_pool_host: Any,
    ) -> HiCacheStorage:
        """Create built-in backend with its backend-specific constructor call."""
        if backend_name == "file":
            return backend_class(storage_config)
        elif backend_name == "nixl":
            return backend_class()
        elif backend_name == "mooncake":
            return backend_class(storage_config)
        elif backend_name == "aibrix":
            return backend_class(storage_config, mem_pool_host)
        elif backend_name == "hf3fs":
            # Calculate bytes_per_page based on memory pool layout.
            if mem_pool_host.layout == "page_first":
                bytes_per_page = (
                    mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
                )
            elif mem_pool_host.layout == "layer_first":
                bytes_per_page = (
                    mem_pool_host.get_size_per_token() * mem_pool_host.page_size
                )
            else:
                # BUGFIX: previously fell through with bytes_per_page unbound,
                # raising an opaque NameError for any other layout.
                raise ValueError(
                    f"Unsupported memory pool layout for hf3fs backend: "
                    f"{mem_pool_host.layout}"
                )

            dtype = mem_pool_host.dtype
            return backend_class.from_env_config(bytes_per_page, dtype, storage_config)
        elif backend_name == "eic":
            return backend_class(storage_config, mem_pool_host)
        else:
            raise ValueError(f"Unknown built-in backend: {backend_name}")
188
+
189
+
190
# Register built-in storage backends.  Each entry is lazy-loaded: the target
# module is imported only when the backend is actually created.
_BUILTIN_BACKENDS = (
    ("file", "sglang.srt.mem_cache.hicache_storage", "HiCacheFile"),
    ("nixl", "sglang.srt.mem_cache.storage.nixl.hicache_nixl", "HiCacheNixl"),
    (
        "mooncake",
        "sglang.srt.mem_cache.storage.mooncake_store.mooncake_store",
        "MooncakeStore",
    ),
    (
        "hf3fs",
        "sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs",
        "HiCacheHF3FS",
    ),
    (
        "aibrix",
        "sglang.srt.mem_cache.storage.aibrix_kvcache.aibrix_kvcache_storage",
        "AibrixKVCacheStorage",
    ),
    ("eic", "sglang.srt.mem_cache.storage.eic.eic_storage", "EICStorage"),
)

for _name, _module_path, _class_name in _BUILTIN_BACKENDS:
    StorageBackendFactory.register_backend(_name, _module_path, _class_name)