sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/nsa/nsa_indexer.py (new file, entry 87 above)
@@ -0,0 +1,761 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu
+
+if is_cuda():
+    import deep_gemm
+
+from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
+from sglang.srt.layers.dp_attention import get_attention_tp_group
+from sglang.srt.layers.linear import ReplicatedLinear
+from sglang.srt.layers.quantization import deep_gemm_wrapper
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool
+
+DUAL_STREAM_TOKEN_THRESHOLD = 1024 if is_cuda() else 0
+
+
+class BaseIndexerMetadata(ABC):
+    @abstractmethod
+    def get_seqlens_int32(self) -> torch.Tensor:
+        """
+        Return: (batch_size,) int32 tensor
+        """
+
+    @abstractmethod
+    def get_page_table_64(self) -> torch.Tensor:
+        """
+        Return: (batch_size, num_blocks) int32, page table.
+        The page size of the table is 64.
+        """
+
+    @abstractmethod
+    def get_seqlens_expanded(self) -> torch.Tensor:
+        """
+        Return: (sum_extend_seq_len,) int32 tensor
+        """
+
+    @abstractmethod
+    def topk_transform(
+        self,
+        logits: torch.Tensor,
+        topk: int,
+    ) -> torch.Tensor:
+        """
+        Perform topk selection on the logits and possibly transform the result.
+
+        NOTE that attention backend may override this function to do some
+        transformation, which means the result of this topk_transform may not
+        be the topk indices of the input logits.
+
+        Return: Anything, since it will be passed to the attention backend
+        for further processing on sparse attention computation.
+        Don't assume it is the topk indices of the input logits.
+        """
+
+
+def rotate_activation(x: torch.Tensor) -> torch.Tensor:
+    assert x.dtype == torch.bfloat16
+    from fast_hadamard_transform import hadamard_transform
+
+    hidden_size = x.size(-1)
+    assert (
+        hidden_size & (hidden_size - 1)
+    ) == 0, "Hidden size must be a power of 2 for Hadamard transform."
+    return hadamard_transform(x, scale=hidden_size**-0.5)
+
+
+class V32LayerNorm(nn.Module):
+    """
+    Layer Normalization.
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
+
+    def forward(self, x: torch.Tensor):
+        return F.layer_norm(
+            x.float(), (self.dim,), self.weight, self.bias, self.eps
+        ).type_as(x)
+
+
+class Indexer(CustomOp):
+    def __init__(
+        self,
+        hidden_size: int,
+        index_n_heads: int,
+        index_head_dim: int,
+        rope_head_dim: int,
+        index_topk: int,
+        q_lora_rank: int,
+        max_position_embeddings: int,
+        rope_theta: float,
+        layer_id: int,
+        scale_fmt: Optional[str],
+        block_size: int = 128,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        prefix: str = "",
+        quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.n_heads = index_n_heads
+        self.head_dim = index_head_dim
+        self.rope_head_dim = rope_head_dim
+        self.index_topk = index_topk
+        self.q_lora_rank = q_lora_rank
+        self.layer_id = layer_id
+        self.alt_stream = alt_stream
+        if is_cuda():
+            self.sm_count = deep_gemm.get_num_sms()
+            self.half_device_sm_count = align(self.sm_count // 2, 8)
+
+        self.wq_b = ReplicatedLinear(
+            self.q_lora_rank,
+            self.n_heads * self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("wq_b", prefix),
+        )
+        self.wk = ReplicatedLinear(
+            self.hidden_size,
+            self.head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("wk", prefix),
+        )
+        self.k_norm = V32LayerNorm(self.head_dim)
+        # NOTE: weight_proj is not quantized
+        self.weights_proj = ReplicatedLinear(
+            self.hidden_size,
+            self.n_heads,
+            bias=False,
+            prefix=add_prefix("weights_proj", prefix),
+        )
+        self.rotary_emb = get_rope_wrapper(
+            rope_head_dim,
+            rotary_dim=rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,  # type: ignore
+            rope_scaling=rope_scaling,
+            is_neox_style=False,
+            device=global_server_args_dict["device"],
+        )
+        self.block_size = block_size
+        self.scale_fmt = scale_fmt
+        self.softmax_scale = self.head_dim**-0.5
+
+    def _forward_fake(
+        self,
+        x: torch.Tensor,
+        q_lora: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+    ):
+        bs = x.shape[0]
+        assert self.index_topk == 2048
+        ans = torch.arange(0, self.index_topk, dtype=torch.int32, device=x.device)[
+            None, ...
+        ].repeat(bs, 1)
+        if forward_batch.forward_mode.is_extend():
+            assert (
+                forward_batch.extend_seq_lens_cpu is not None
+                and forward_batch.seq_lens_cpu is not None
+            )
+            which = 0
+            for i, (kv_len, qo_len) in enumerate(
+                zip(
+                    forward_batch.seq_lens_cpu.tolist(),
+                    forward_batch.extend_seq_lens_cpu,
+                    strict=True,
+                )
+            ):
+                for j in range(kv_len - qo_len, kv_len):
+                    ans[which, j + 1 :] = -1
+                    which += 1
+            assert which == ans.shape[0]
+        else:
+            assert forward_batch.seq_lens_cpu is not None
+            for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
+                ans[i, seq_len:] = -1
+
+        return ans
+
+    def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
+        weights, _ = self.weights_proj(x)
+        weights = weights * self.n_heads**-0.5
+        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
+        return weights
+
+    def _get_q_k_bf16(
+        self,
+        q_lora: torch.Tensor,
+        x: torch.Tensor,
+        positions: torch.Tensor,
+        enable_dual_stream: bool,
+    ):
+
+        if enable_dual_stream:
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+
+            with deep_gemm_wrapper.configure_deep_gemm_num_sms(
+                self.half_device_sm_count
+            ):
+                query, _ = self.wq_b(q_lora)
+                query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
+                q_rope, _ = torch.split(
+                    query,
+                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
+                    dim=-1,
+                )
+            with torch.cuda.stream(self.alt_stream):
+                # TODO we should also put DeepGEMM half SM here?
+                key, _ = self.wk(x)
+                key = self.k_norm(key)
+
+                k_rope, _ = torch.split(
+                    key,
+                    [self.rope_head_dim, self.head_dim - self.rope_head_dim],
+                    dim=-1,
+                )
+
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            query, _ = self.wq_b(q_lora)
+            query = rearrange(query, "l (h d) -> l h d", d=self.head_dim)
+
+            q_rope, _ = torch.split(
+                query, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
+            )
+
+            key, _ = self.wk(x)
+            key = self.k_norm(key)
+            k_rope, _ = torch.split(
+                key, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1
+            )
+
+        q_rope, k_rope = self.rotary_emb(positions, q_rope, k_rope)
+
+        query[..., : self.rope_head_dim] = q_rope
+        key[..., : self.rope_head_dim] = k_rope
+
+        if enable_dual_stream:
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            query = rotate_activation(query)
+
+            with torch.cuda.stream(self.alt_stream):
+                key = rotate_activation(key)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            query = rotate_activation(query)
+            key = rotate_activation(key)
+
+        return query, key
+
+    def _get_topk_paged(
+        self,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+        q_fp8: torch.Tensor,
+        weights: torch.Tensor,
+        metadata: BaseIndexerMetadata,
+    ) -> torch.Tensor:
+        if TYPE_CHECKING:
+            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
+
+        page_size = forward_batch.token_to_kv_pool.page_size
+        # NOTE(dark): blocksize = 64 is hardcoded in deep_gemm
+        assert page_size == 64, "only support page size 64"
+
+        # NOTE(dark): this support extend/decode/decode+graph
+        block_tables = metadata.get_page_table_64()
+
+        max_seq_len = block_tables.shape[1] * page_size
+        kv_cache_fp8 = forward_batch.token_to_kv_pool.get_index_k_with_scale_buffer(
+            layer_id=layer_id
+        )
+
+        blocksize = page_size
+        seqlens_32 = metadata.get_seqlens_int32()
+        # NOTE(dark): 132 is SM count on H200/B200, not magic number
+        schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(
+            seqlens_32, blocksize, self.sm_count
+        )
+
+        assert len(q_fp8.shape) == 3
+        q_fp8 = q_fp8.unsqueeze(1)  # the next_n dim is 1 now
+        assert len(kv_cache_fp8.shape) == 2
+        block_kv = 64
+        num_heads_kv = 1
+        head_dim_with_sf = 132
+        kv_cache_fp8 = kv_cache_fp8.view(
+            kv_cache_fp8.shape[0], block_kv, num_heads_kv, head_dim_with_sf
+        )
+        assert len(weights.shape) == 3
+        weights = weights.squeeze(2)
+
+        logits = deep_gemm.fp8_paged_mqa_logits(
+            q_fp8,
+            kv_cache_fp8,
+            weights,
+            seqlens_32,
+            block_tables,
+            schedule_metadata,
+            max_seq_len,
+            clean_logits=False,
+        )
+
+        # NOTE(dark): logits should be cleaned in topk_transform
+        topk_result = metadata.topk_transform(logits, self.index_topk)
+        return topk_result
+
+    def _get_topk_ragged(
+        self,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+        q_fp8: torch.Tensor,
+        weights: torch.Tensor,
+        metadata: BaseIndexerMetadata,
+    ) -> torch.Tensor:
+        if TYPE_CHECKING:
+            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
+
+        page_size = forward_batch.token_to_kv_pool.page_size
+        assert page_size == 64, "only support page size 64"
+        assert len(weights.shape) == 3
+        weights = weights.squeeze(-1)
+        k_fp8_list = []
+        k_scale_list = []
+        ks_list = []
+        offset = 0
+
+        block_tables = metadata.get_page_table_64()
+
+        assert (
+            forward_batch.seq_lens_cpu is not None
+            and forward_batch.extend_seq_lens_cpu is not None
+        )
+
+        for i in range(forward_batch.batch_size):
+            seq_len = forward_batch.seq_lens_cpu[i].item()
+            assert isinstance(seq_len, int)
+            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
+                layer_id,
+                seq_len,
+                block_tables[i],
+            )
+            k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
+                layer_id,
+                seq_len,
+                block_tables[i],
+            )
+            extend_seq_len = forward_batch.extend_seq_lens_cpu[i]
+            ks = torch.full((extend_seq_len,), offset, dtype=torch.int32, device="cuda")
+            k_fp8_list.append(k_fp8)
+            k_scale_list.append(k_scale)
+            ks_list.append(ks)
+            offset += extend_seq_len
+
+        k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
+        k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
+        kv_fp8 = (k_fp8, k_scale)
+        ks = torch.cat(ks_list, dim=0)
+        seq_lens_expanded = metadata.get_seqlens_expanded()
+        ke = ks + seq_lens_expanded
+
+        logits = deep_gemm.fp8_mqa_logits(
+            q_fp8,
+            kv_fp8,
+            weights,
+            ks,
+            ke,
+            clean_logits=False,
+        )
+
+        assert logits.shape[0] == len(seq_lens_expanded)
+        topk_result = metadata.topk_transform(logits, self.index_topk)
+
+        return topk_result
+
+    def forward_indexer_bs_1(
+        self,
+        q_fp8: torch.Tensor,
+        weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+        topk: int,
+        layer_id: int,
+    ) -> Optional[torch.Tensor]:
+        if not is_npu():
+            from sglang.srt.layers.attention.nsa.tilelang_kernel import fp8_index
+
+        page_size = forward_batch.token_to_kv_pool.page_size
+        assert page_size == 64, "only support page size 64"
+
+        assert len(weights.shape) == 3
+        weights = weights.squeeze(-1)
+
+        # logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
+        k_fp8_list = []
+        k_scale_list = []
+
+        topk_indices_list = []
+
+        block_tables = forward_batch.req_to_token_pool.req_to_token[
+            forward_batch.req_pool_indices, :
+        ]
+        strided_indices = torch.arange(
+            0, block_tables.shape[-1], page_size, device="cuda"
+        )
+        block_tables = block_tables[:, strided_indices] // page_size
+
+        q_len_start = 0
+
+        for i in range(forward_batch.batch_size):
+            seq_len = forward_batch.seq_lens[i].item()
+            q_len = (
+                forward_batch.extend_seq_lens_cpu[i]
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+            q_len_end = q_len_start + q_len
+
+            q_fp8_partial = q_fp8[q_len_start:q_len_end]
+            q_fp8_partial = q_fp8_partial.unsqueeze(0).contiguous()
+
+            weights_partial = weights[q_len_start:q_len_end]
+            weights_partial = weights_partial.squeeze(-1).unsqueeze(0).contiguous()
+
+            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
+                layer_id,
+                seq_len,
+                block_tables[i],
+            )
+            k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
+                layer_id,
+                seq_len,
+                block_tables[i],
+            )
+
+            k_fp8 = k_fp8.view(torch.float8_e4m3fn).unsqueeze(0).contiguous()
+            k_scale = k_scale.view(torch.float32).squeeze(-1).unsqueeze(0).contiguous()
+
+            index_score = fp8_index(
+                q_fp8_partial,
+                weights_partial,
+                k_fp8,
+                k_scale,
+            )
+            end_pos = seq_len
+            topk_indices = index_score.topk(min(topk, end_pos), dim=-1)[1].squeeze(0)
+
+            pad_len = align(topk_indices.shape[-1], 2048) - topk_indices.shape[-1]
+            topk_indices = torch.nn.functional.pad(
+                topk_indices, (0, pad_len), "constant", -1
+            )
+
+            topk_indices_list.append(topk_indices)
+
+            q_len_start = q_len_end
+
+        topk_indices = torch.cat(topk_indices_list, dim=0)
+
+        return topk_indices
+
+    def forward_indexer(
+        self,
+        q_fp8: torch.Tensor,
+        weights: torch.Tensor,
+        forward_batch: ForwardBatch,
+        topk: int,
+        layer_id: int,
+    ) -> Optional[torch.Tensor]:
+        return self.forward_indexer_bs_1(q_fp8, weights, forward_batch, topk, layer_id)
+
+    def _forward(
+        self,
+        x: torch.Tensor,
+        q_lora: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+    ) -> Optional[torch.Tensor]:
+        if not is_npu():
+            from sglang.srt.layers.attention.nsa.tilelang_kernel import act_quant
+
+        if TYPE_CHECKING:
+            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
+
+        metadata = forward_batch.attn_backend.get_indexer_metadata(
+            layer_id, forward_batch
+        )
+
+        enable_dual_stream = (
+            NSA_DUAL_STREAM
+            and self.alt_stream is not None
+            and get_is_capture_mode()
+            and q_lora.shape[0] > 0
+            and q_lora.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+        )
+
+        # skip NSA if attention backend choose to skip this batch
+        if metadata is None:
+            return None
+
+        if not NSA_USE_REAL_INDEXER:  # temporary
+            return self._forward_fake(x, q_lora, positions, forward_batch, layer_id)
+
+        query, key = self._get_q_k_bf16(q_lora, x, positions, enable_dual_stream)
+
+        if enable_dual_stream:
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+
+            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
+            with torch.cuda.stream(self.alt_stream):
+                k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_fp8, q_scale = act_quant(query, self.block_size, self.scale_fmt)
+            k_fp8, k_scale = act_quant(key, self.block_size, self.scale_fmt)
+
+        # k_fp8: (seq_len, head_dim) fp8_e4m3fn
+        # k_buffer: (num_total_tokens + page_size, head_dim) fp8_e4m3fn
+        # k_scale: (seq_len, head_dim // block_size = 1) fp8_e4m3fn
+        # k_scale_cache: (num_total_tokens + page_size, head_dim // block_size = 1) fp8_e4m3fn
+        forward_batch.token_to_kv_pool.set_index_k_and_scale_buffer(
+            layer_id=layer_id,
+            loc=forward_batch.out_cache_loc,
+            index_k=k_fp8,
+            index_k_scale=k_scale,
+        )
+
+        weights = self._get_logits_head_gate(x, q_scale)
+
+        if is_cuda():
+            assert forward_batch.seq_lens_cpu is not None
+            if len(forward_batch.seq_lens_cpu) == 0:
+                # this seems b/c max-pad, no worries?
+                # if x.shape[0] != 0:
+                #     print(
+                #         "HACK: seq_lens empty but x not empty, hackily return all-invalid topk_result"
+                #     )
+                return torch.full(
+                    (x.shape[0], self.index_topk), -1, dtype=torch.int, device="cuda"
+                )
+
+            if forward_batch.forward_mode.is_decode_or_idle():
+                topk_result = self._get_topk_paged(
+                    forward_batch, layer_id, q_fp8, weights, metadata
+                )
+            else:
+                topk_result = self._get_topk_ragged(
+                    forward_batch, layer_id, q_fp8, weights, metadata
+                )
+        else:
+            topk_result = self.forward_indexer(
+                q_fp8.contiguous(),
+                weights,
+                forward_batch,
+                topk=self.index_topk,
+                layer_id=layer_id,
+            )
+
+        return topk_result
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        q_lora: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+    ) -> Optional[torch.Tensor]:
+        return self._forward(x, q_lora, positions, forward_batch, layer_id)
+
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        q_lora: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+    ) -> torch.Tensor:
+        import custom_ops
+        import torch_npu
+
+        from sglang.srt.layers.dp_attention import (
+            get_attention_tp_rank,
+            get_attention_tp_size,
+        )
+        from sglang.srt.utils import get_bool_env_var
+
+        if forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = forward_batch.attn_backend.forward_metadata.seq_lens
+        else:
+            actual_seq_lengths_kv = (
+                forward_batch.attn_backend.forward_metadata.seq_lens_cpu_int
+            )
+        enable_index_cp = (
+            get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
+        )
+        is_prefill = forward_batch.forward_mode.is_extend()
+
+        attention_tp_rank = get_attention_tp_rank()
+        attention_tp_size = get_attention_tp_size()
+
+        cos_sin = self.rotary_emb.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        cos = cos.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
+        sin = sin.repeat(1, 2).view(-1, 1, 1, self.rope_head_dim)
+        if is_prefill and enable_index_cp:
+            slice_length = cos.shape[0] // attention_tp_size
+            cos = cos[
+                slice_length
+                * attention_tp_rank : slice_length
+                * (attention_tp_rank + 1)
+            ]
+            sin = sin[
+                slice_length
+                * attention_tp_rank : slice_length
+                * (attention_tp_rank + 1)
+            ]
+
+        slot_mapping = forward_batch.out_cache_loc
+        block_table = forward_batch.attn_backend.forward_metadata.block_tables
+
+        bs = x.shape[0]
+
+        q = self.wq_b(q_lora)[0]  # [bs, 1536] @ [1536, 64 * 128] = [bs, 64 * 128]
+        q = q.view(bs, self.n_heads, self.head_dim)  # [bs, 64, 128]
+        q_pe, q_nope = torch.split(
+            q,
+            [self.rope_head_dim, self.head_dim - self.rope_head_dim],
+            dim=-1,
+        )  # [bs, 64, 64 + 64]
+
+        q_pe = q_pe.view(bs, self.n_heads, 1, self.rope_head_dim)
+        q_pe = torch_npu.npu_interleave_rope(q_pe, cos, sin).view(
+            bs, self.n_heads, self.rope_head_dim
+        )  # [bs, n, d]
+        q = torch.cat([q_pe, q_nope], dim=-1)
+
+        k_proj = self.wk(x)[0]  # [b, s, 7168] @ [7168, 128] = [b, s, 128]
+        k = self.k_norm(k_proj)
+        k_pe, k_nope = torch.split(
+            k,
+            [self.rope_head_dim, self.head_dim - self.rope_head_dim],
+            dim=-1,
+        )  # [bs, 64 + 64]
+
+        k_pe = k_pe.view(-1, 1, 1, self.rope_head_dim)
+        k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin).view(
+            bs, 1, self.rope_head_dim
+        )  # [bs, 1, d]
+        k = torch.cat([k_pe, k_nope.unsqueeze(1)], dim=-1)  # [bs, 1, 128]
+
+        if is_prefill and enable_index_cp:
+            k, local_k = (
+                torch.empty(
+                    (k.shape[0] * attention_tp_size, k.shape[1], k.shape[2]),
+                    dtype=k.dtype,
+                    device=k.device,
+                ),
+                k,
+            )
+            get_attention_tp_group().all_gather_into_tensor(k, local_k)
+
+        forward_batch.token_to_kv_pool.set_index_k_buffer(layer_id, slot_mapping, k)
+
+        indexer_input = {}
+        if is_prefill:
+            actual_seq_lengths_kv = forward_batch.seq_lens.to(device=q.device)
+            actual_seq_lengths_q = forward_batch.seq_lens.cumsum(dim=0).to(
+                device=q.device
+            )
+            if enable_index_cp:
+                actual_seq_lengths_q -= bs * attention_tp_rank
+                actual_seq_lengths_q = torch.max(
+                    actual_seq_lengths_q,
+                    torch.zeros_like(actual_seq_lengths_q).to(
+                        device=actual_seq_lengths_q.device
+                    ),
+                )
+                actual_seq_lengths_q = torch.min(
+                    actual_seq_lengths_q,
+                    torch.full(actual_seq_lengths_q.shape, bs).to(
+                        device=actual_seq_lengths_q.device
+                    ),
+                )
+
+        else:
+            if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
+                actual_seq_lengths_q = torch.tensor(
+                    [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=k.device
+                )
+            else:
+                actual_seq_lengths_q = (
+                    forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q
+                )
+
+        past_key_states = forward_batch.token_to_kv_pool.get_index_k_buffer(layer_id)
+
+        x = x.view(-1, self.hidden_size)
+        weights = self.weights_proj(x)[0]
+        block_table = (
+            block_table[: actual_seq_lengths_q.size()[0]] if is_prefill else block_table
+        )
+
+        topk_indices = torch.ops.custom.npu_lightning_indexer(
+            query=q.view(-1, self.n_heads, self.head_dim),
+            key=past_key_states,
+            weights=weights,
+            actual_seq_lengths_query=actual_seq_lengths_q.to(torch.int32),
+            actual_seq_lengths_key=actual_seq_lengths_kv.to(k.device).to(torch.int32),
+            block_table=block_table,
+            layout_query="TND",
+            layout_key="PA_BSND",
+            sparse_count=self.index_topk,
+            sparse_mode=3,
+        )
+
+        if is_prefill and enable_index_cp:
+            topk_indices, local_topk_indices = (
+                torch.empty(
+                    (
+                        topk_indices.shape[0] * attention_tp_size,
+                        topk_indices.shape[1],
+                        topk_indices.shape[2],
+                    ),
+                    dtype=topk_indices.dtype,
+                    device=topk_indices.device,
+                ),
+                topk_indices,
+            )
+            get_attention_tp_group().all_gather_into_tensor(
+                topk_indices, local_topk_indices
+            )
+
+        return topk_indices
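
The BaseIndexerMetadata ABC at the top of the new file is the contract an attention backend must satisfy to drive the indexer. A minimal sketch of a conforming implementation for a single decode request (the class and its values are hypothetical; only the four method signatures come from the ABC, and a real backend would keep these tensors on the GPU):

import torch

from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata


class ToyIndexerMetadata(BaseIndexerMetadata):
    """Hypothetical metadata for one decode request of length 70."""

    def __init__(self, seq_len: int, page_table_64: torch.Tensor):
        self.seq_len = seq_len
        # (batch_size, num_blocks) int32, page size 64
        self.page_table_64 = page_table_64

    def get_seqlens_int32(self) -> torch.Tensor:
        return torch.tensor([self.seq_len], dtype=torch.int32)

    def get_page_table_64(self) -> torch.Tensor:
        return self.page_table_64

    def get_seqlens_expanded(self) -> torch.Tensor:
        # one entry per query token; a decode step has a single query token
        return torch.tensor([self.seq_len], dtype=torch.int32)

    def topk_transform(self, logits: torch.Tensor, topk: int) -> torch.Tensor:
        # The simplest reading of the contract: plain top-k indices. Real
        # backends may clean invalid logits here and return a different
        # layout, which is why the docstring warns callers not to assume
        # top-k indices come back.
        return logits.topk(min(topk, logits.shape[-1]), dim=-1).indices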
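
rotate_activation delegates to the fast_hadamard_transform CUDA kernel. For sanity-checking it on small inputs, a slow reference that builds the Sylvester-ordered Hadamard matrix explicitly can be useful; this is an illustration under that ordering assumption, not package code:

import torch


def hadamard_reference(x: torch.Tensor) -> torch.Tensor:
    n = x.size(-1)
    assert n & (n - 1) == 0, "hidden size must be a power of two"
    # Sylvester construction: H_1 = [1], H_2n = [[H, H], [H, -H]]
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat(
            [torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0
        )
    # same scale as rotate_activation: hidden_size ** -0.5
    return (x.float() @ h.to(x.device)) * n**-0.5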
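
The dual-stream sections in _get_q_k_bf16 and _forward all follow the same fork/join pattern on a secondary CUDA stream. Reduced to its skeleton (the helper name is illustrative, not from the package):

import torch


def fork_join(fn_main, fn_alt, alt_stream: torch.cuda.Stream):
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)   # alt stream sees all prior work
    out_main = fn_main()              # issued on the current stream
    with torch.cuda.stream(alt_stream):
        out_alt = fn_alt()            # issued concurrently on the alt stream
    current.wait_stream(alt_stream)   # join before consuming out_alt
    return out_main, out_alt

Per the enable_dual_stream condition in _forward, this overlap is only attempted during CUDA-graph capture with between 1 and DUAL_STREAM_TOKEN_THRESHOLD (1024) tokens, where the query and key projections are small enough that splitting SMs between the streams (half_device_sm_count) can pay off.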
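
In forward_indexer_bs_1, each request's indices are padded to a fixed stride before concatenation. Assuming align(n, m) rounds n up to the nearest multiple of m (consistent with its use as align(self.sm_count // 2, 8) in __init__), the convention works out to:

import torch
import torch.nn.functional as F

from sglang.srt.utils import align

topk_indices = torch.arange(70)  # suppose 70 positions survive the top-k
pad_len = align(topk_indices.shape[-1], 2048) - topk_indices.shape[-1]
topk_indices = F.pad(topk_indices, (0, pad_len), "constant", -1)
# 70 real indices followed by 1978 sentinel -1 entries, 2048 in total,
# so per-request rows can be concatenated into one rectangular result
assert topk_indices.numel() == 2048 and (topk_indices[70:] == -1).all()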