sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/nsa/index_buf_accessor.py
@@ -0,0 +1,354 @@
+from typing import TYPE_CHECKING
+
+import torch
+import triton
+import triton.language as tl
+
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
+
+"""
+k: data, 128 item per token, fp8
+s: scale, 1 item per token, fp32
+"""
+
+
+class GetK:
+    @classmethod
+    def execute(cls, *args, **kwargs):
+        return cls.torch_fast(*args, **kwargs)
+
+    @classmethod
+    def slow(
+        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
+    ):
+        num_pages = (seq_len + pool.page_size - 1) // pool.page_size
+        seq_len_ = num_pages * pool.page_size
+        index_k_fp8 = torch.empty(
+            (seq_len_, pool.index_head_dim),
+            dtype=torch.uint8,
+            device=pool.device,
+        )
+        for i in range(num_pages):
+            page_index = page_indices[i]
+            index_k_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
+                page_index
+            ][: pool.page_size * pool.index_head_dim].view(-1, pool.index_head_dim)
+
+        return index_k_fp8[:seq_len]
+
+    @classmethod
+    def torch_fast(
+        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
+    ):
+        """
+        :param page_indices: (num_pages,), int32
+        :return: (seq_len, index_head_dim), uint8
+        """
+
+        # can handle per 128B instead of per element
+
+        # page_indices: (num_pages,), element := a page index
+        buf_numel_per_page = buf.shape[1]
+
+        num_k_bytes_per_page = pool.page_size * pool.index_head_dim
+        num_k_bytes_per_token = pool.index_head_dim
+
+        # buf: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4), uint8
+        # flat_buf: (whatever,), uint8
+        flat_buf = buf.flatten()
+
+        # flat_indices: (num_pages, num_k_bytes_per_page), int32, element := an index into flat_buf that we want to access
+        flat_indices = (page_indices * buf_numel_per_page)[:, None] + torch.arange(
+            num_k_bytes_per_page, dtype=torch.int32, device="cuda"
+        )[None, :]
+        flat_indices = flat_indices.flatten()[: seq_len * num_k_bytes_per_token]
+
+        out = flat_buf[flat_indices]
+        return out.view(-1, 128)
+
+
+class GetS:
+    @classmethod
+    def execute(cls, *args, **kwargs):
+        return cls.torch_fast(*args, **kwargs)
+
+    @classmethod
+    def slow(
+        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
+    ):
+        num_pages = (seq_len + pool.page_size - 1) // pool.page_size
+        seq_len_ = num_pages * pool.page_size
+        assert pool.index_head_dim // pool.quant_block_size == 1
+        index_k_scale_fp8 = torch.empty(
+            (seq_len_, 4),
+            dtype=torch.uint8,
+            device=pool.device,
+        )
+        for i in range(num_pages):
+            page_index = page_indices[i]
+            index_k_scale_fp8[i * pool.page_size : (i + 1) * pool.page_size] = buf[
+                page_index
+            ][pool.page_size * pool.index_head_dim :].view(-1, 4)
+        return index_k_scale_fp8[:seq_len]
+
+    @classmethod
+    def torch_fast(
+        cls, pool: "NSATokenToKVPool", buf, seq_len: int, page_indices: torch.Tensor
+    ):
+        """
+        :param page_indices: (num_pages,), int32
+        :return: (seq_len, index_head_dim // quant_block_size), uint8
+        """
+        buf_numel_per_page = buf.shape[1]
+
+        num_s_bytes_per_page = buf.shape[1] - pool.page_size * pool.index_head_dim
+        num_s_bytes_per_token = pool.index_head_dim // pool.quant_block_size * 4
+        s_offset_in_page = pool.page_size * pool.index_head_dim
+
+        flat_buf = buf.flatten()
+        flat_indices = (
+            (page_indices * buf_numel_per_page)[:, None]
+            + torch.arange(num_s_bytes_per_page, dtype=torch.int32, device="cuda")[
+                None, :
+            ]
+            + s_offset_in_page
+        )
+        flat_indices = flat_indices.flatten()[: seq_len * num_s_bytes_per_token]
+
+        out = flat_buf[flat_indices]
+        return out.view(-1, 4)
+
+
+class SetK:
+    @classmethod
+    def execute(cls, *args, buf, **kwargs):
+        return cls.torch_fast(*args, **kwargs, buf=buf)
+
+    @classmethod
+    def slow(
+        cls,
+        pool: "NSATokenToKVPool",
+        buf: torch.Tensor,
+        loc: torch.Tensor,
+        index_k: torch.Tensor,
+    ):
+        for i in range(len(loc)):
+            page_index = loc[i] // pool.page_size
+            offset = loc[i] % pool.page_size
+            buf[
+                page_index,
+                offset * pool.index_head_dim : (offset + 1) * pool.index_head_dim,
+            ] = index_k[i].view(torch.uint8)
+
+    @classmethod
+    def torch_fast(
+        cls,
+        pool: "NSATokenToKVPool",
+        buf: torch.Tensor,
+        loc: torch.Tensor,
+        index_k: torch.Tensor,
+    ):
+        (num_tokens_to_write,) = loc.shape
+        buf_numel_per_page = buf.shape[1]
+        num_k_bytes_per_token = pool.index_head_dim
+
+        # loc: (num_tokens_to_write,), int32, element := the token index to write to
+        loc_page_index = loc // pool.page_size
+        loc_token_offset_in_page = loc % pool.page_size
+
+        flat_buf = buf.flatten()
+        flat_indices = (
+            (loc_page_index * buf_numel_per_page)[:, None]
+            + (loc_token_offset_in_page * num_k_bytes_per_token)[:, None]
+            + torch.arange(num_k_bytes_per_token, dtype=torch.int32, device="cuda")[
+                None, :
+            ]
+        )
+        num_k_bytes_total = num_tokens_to_write * num_k_bytes_per_token
+        flat_indices = flat_indices.flatten()[:num_k_bytes_total]
+        flat_buf[flat_indices] = index_k.view(torch.uint8).flatten()
+
+
+class SetS:
+    @classmethod
+    def execute(cls, *args, buf, **kwargs):
+        return cls.torch_fast(*args, **kwargs, buf=buf)
+
+    @classmethod
+    def slow(
+        cls,
+        pool: "NSATokenToKVPool",
+        buf: torch.Tensor,
+        loc: torch.Tensor,
+        index_k_scale: torch.Tensor,
+    ):
+        for i in range(len(loc)):
+            page_index = loc[i] // pool.page_size
+            offset = loc[i] % pool.page_size
+            start = pool.page_size * pool.index_head_dim
+            buf[page_index, start + offset * 4 : start + (offset + 1) * 4] = (
+                index_k_scale[i].view(torch.uint8)
+            )
+
+    @classmethod
+    def torch_fast(
+        cls,
+        pool: "NSATokenToKVPool",
+        buf: torch.Tensor,
+        loc: torch.Tensor,
+        index_k_scale: torch.Tensor,
+    ):
+        (num_tokens_to_write,) = loc.shape
+        buf_numel_per_page = buf.shape[1]
+        num_s_bytes_per_token = 4
+        s_offset_in_page = pool.page_size * pool.index_head_dim
+
+        # loc: (num_tokens_to_write,), int32, element := the token index to write to
+        loc_page_index = loc // pool.page_size
+        loc_token_offset_in_page = loc % pool.page_size
+
+        flat_buf = buf.flatten()
+        flat_indices = (
+            (loc_page_index * buf_numel_per_page)[:, None]
+            + s_offset_in_page
+            + (loc_token_offset_in_page * num_s_bytes_per_token)[:, None]
+            + torch.arange(num_s_bytes_per_token, dtype=torch.int32, device="cuda")[
+                None, :
+            ]
+        )
+        number_s_bytes_total = num_tokens_to_write * num_s_bytes_per_token
+        flat_indices = flat_indices.flatten()[:number_s_bytes_total]
+        flat_buf[flat_indices] = index_k_scale.view(torch.uint8).flatten()
+
+
+class SetKAndS:
+    @classmethod
+    def execute(cls, *args, buf, **kwargs):
+        if 0:
+            # print("SetK, SetS comparison test")
+            buf_cloned = buf.clone()
+            cls.vanilla(*args, **kwargs, buf=buf)
+            cls.triton(*args, **kwargs, buf=buf_cloned)
+
+            def _clear_token_0(target):
+                target[0, :128] = target[0, 64 * 128 : 64 * 128 + 4] = 0
+
+            _clear_token_0(buf)
+            _clear_token_0(buf_cloned)
+
+            assert torch.all(
+                buf == buf_cloned
+            ), f"{buf=} {buf_cloned=} {kwargs['loc'].to_list()=}"
+            return
+
+        cls.triton(*args, **kwargs, buf=buf)
+
+    @classmethod
+    def vanilla(cls, pool, buf, loc, index_k, index_k_scale):
+        SetK.execute(pool=pool, buf=buf, loc=loc, index_k=index_k)
+        SetS.execute(pool=pool, buf=buf, loc=loc, index_k_scale=index_k_scale)
+
+    @classmethod
+    def triton(cls, pool, buf, loc, index_k, index_k_scale):
+        _set_k_and_s_triton(
+            buf=buf,
+            loc=loc,
+            index_k=index_k,
+            index_k_scale=index_k_scale,
+            page_size=pool.page_size,
+        )
+
+
+def _set_k_and_s_triton(
+    buf: torch.Tensor,
+    loc: torch.Tensor,
+    index_k: torch.Tensor,
+    index_k_scale: torch.Tensor,
+    page_size: int,
+):
+    """
+    :param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8
+    :param loc: (num_tokens_to_write,), int, element := the token index to write to
+    :param index_k: (num_tokens_to_write, 128 elem), fp8
+    :param index_k_scale: (num_tokens_to_write, 1 elem), fp32
+    :return:
+    """
+    num_pages, buf_numel_per_page = buf.shape
+    (num_tokens_to_write,) = loc.shape
+    num_tokens_to_write_, index_head_dim = index_k.shape
+    num_tokens_to_write__, scale_dim = index_k_scale.shape
+    assert buf_numel_per_page == 64 * (128 + 4)
+    assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__
+    assert index_head_dim == 128
+    assert scale_dim == 1
+    assert page_size == 64
+
+    assert buf.dtype == torch.uint8
+    assert loc.dtype == torch.int64, f"{loc.dtype=}"  # can be int32
+    assert index_k.dtype == torch.float8_e4m3fn
+    assert index_k_scale.dtype == torch.float32
+
+    assert buf.is_contiguous()
+    assert loc.is_contiguous()
+    assert index_k.is_contiguous()
+    assert index_k_scale.is_contiguous()
+
+    buf_fp8 = buf.view(torch.float8_e4m3fn)
+    buf_fp32 = buf.view(torch.float32)
+
+    _set_k_and_s_triton_kernel[(num_tokens_to_write,)](
+        buf_fp8,
+        buf_fp32,
+        loc,
+        index_k,
+        index_k_scale,
+        index_k.stride(0),
+        PAGE_SIZE=page_size,
+        BUF_NUMEL_PER_PAGE=buf_numel_per_page,
+        NUM_K_ELEMS_PER_TOKEN=index_head_dim,
+        S_OFFSET_NBYTES_IN_PAGE=page_size * index_head_dim,
+    )
+
+
+@triton.jit
+def _set_k_and_s_triton_kernel(
+    buf_fp8_ptr,
+    buf_fp32_ptr,
+    loc_ptr,
+    index_k_ptr,
+    index_k_scale_ptr,
+    index_k_ptr_stride_0,
+    PAGE_SIZE: tl.constexpr,
+    BUF_NUMEL_PER_PAGE: tl.constexpr,
+    NUM_K_ELEMS_PER_TOKEN: tl.constexpr,
+    S_OFFSET_NBYTES_IN_PAGE: tl.constexpr,
+):
+    token_id = tl.program_id(0)
+
+    loc = tl.load(loc_ptr + token_id)
+
+    in_k_offsets = token_id * index_k_ptr_stride_0 + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
+
+    # no need for `mask`, since we read 128B for k and 4B for scale, both pow of 2
+    k = tl.load(index_k_ptr + in_k_offsets)
+    k_scale = tl.load(index_k_scale_ptr + token_id)
+
+    loc_page_index = loc // PAGE_SIZE
+    loc_token_offset_in_page = loc % PAGE_SIZE
+
+    out_k_offsets = (
+        loc_page_index * BUF_NUMEL_PER_PAGE
+        + loc_token_offset_in_page * NUM_K_ELEMS_PER_TOKEN
+        + tl.arange(0, NUM_K_ELEMS_PER_TOKEN)
+    )

+    # "//4" b/c it is fp32 instead of uint8
+    out_s_offset = (
+        loc_page_index * BUF_NUMEL_PER_PAGE // 4
+        + S_OFFSET_NBYTES_IN_PAGE // 4
+        + loc_token_offset_in_page
+    )
+
+    tl.store(buf_fp8_ptr + out_k_offsets, k)
+    tl.store(buf_fp32_ptr + out_s_offset, k_scale)
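
The page layout all of these accessors share, per the asserts in _set_k_and_s_triton: each page holds 64 tokens, stored as 64 * 128 = 8192 bytes of fp8 k data followed by 64 * 4 = 256 bytes of fp32 scales, i.e. 8448 uint8 elements per page. Below is a minimal round-trip sketch of the torch_fast paths, assuming a CUDA device, a PyTorch build with float8_e4m3fn support, and a SimpleNamespace stub (a hypothetical stand-in, not the real NSATokenToKVPool) carrying only the attributes the accessors touch:

# A minimal sketch, assuming SetK/SetS/GetK/GetS from the file above are in
# scope. The `pool` stub is hypothetical; its constants match the asserts in
# _set_k_and_s_triton (page_size 64, index_head_dim 128, 4-byte scales).
from types import SimpleNamespace

import torch

pool = SimpleNamespace(
    page_size=64,          # tokens per page
    index_head_dim=128,    # fp8 k bytes per token
    quant_block_size=128,  # one 4-byte fp32 scale per token
    device="cuda",
)

# Per page: 64 * 128 k bytes + 64 * 4 scale bytes = 8448 uint8 elements.
buf_numel_per_page = pool.page_size * (pool.index_head_dim + 4)
buf = torch.zeros((4, buf_numel_per_page), dtype=torch.uint8, device="cuda")

# Write k vectors and scales for three tokens: slots 0 and 1 (page 0), 70 (page 1).
loc = torch.tensor([0, 1, 70], dtype=torch.int64, device="cuda")
index_k = torch.randn(3, 128, device="cuda").to(torch.float8_e4m3fn)
index_k_scale = torch.rand(3, 1, dtype=torch.float32, device="cuda")
SetK.execute(pool, buf=buf, loc=loc, index_k=index_k)
SetS.execute(pool, buf=buf, loc=loc, index_k_scale=index_k_scale)

# Read the first two tokens back from page 0; both return raw uint8 bytes.
page_indices = torch.tensor([0], dtype=torch.int32, device="cuda")
assert torch.equal(
    GetK.execute(pool, buf, 2, page_indices), index_k[:2].view(torch.uint8)
)
assert torch.equal(
    GetS.execute(pool, buf, 2, page_indices), index_k_scale[:2].view(torch.uint8)
)

SetKAndS.execute instead takes the fused Triton path, launching one program per token that stores the 128-byte k vector and its 4-byte scale in a single kernel rather than two separate gather/scatter passes.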