sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/fla/fused_recurrent.py (new file)
@@ -0,0 +1,640 @@
+# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/fused_recurrent.py
+# -*- coding: utf-8 -*-
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.attention.fla.op import exp
+from sglang.srt.layers.attention.fla.utils import input_guard
+
+
+@triton.heuristics(
+    {
+        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
+        "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
+        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+    }
+)
+@triton.jit(do_not_specialize=["T"])
+def fused_recurrent_gated_delta_rule_fwd_kernel(
+    q,
+    k,
+    v,
+    g,
+    beta,
+    o,
+    h0,
+    ht,
+    cu_seqlens,
+    scale,
+    T,
+    B: tl.constexpr,
+    H: tl.constexpr,
+    HV: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
+    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
+    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,
+    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+):
+    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_n, i_hv = i_nh // HV, i_nh % HV
+    i_h = i_hv // (HV // H)
+    if IS_VARLEN:
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(
+            cu_seqlens + i_n + 1
+        ).to(tl.int64)
+        all = T
+        T = eos - bos
+    else:
+        bos, eos = i_n * T, i_n * T + T
+        all = B * T
+    o_k = i_k * BK + tl.arange(0, BK)
+    o_v = i_v * BV + tl.arange(0, BV)
+
+    p_q = q + (bos * H + i_h) * K + o_k
+    p_k = k + (bos * H + i_h) * K + o_k
+    p_v = v + (bos * HV + i_hv) * V + o_v
+    if IS_BETA_HEADWISE:
+        p_beta = beta + (bos * HV + i_hv) * V + o_v
+    else:
+        p_beta = beta + bos * HV + i_hv
+    p_g = g + bos * HV + i_hv
+    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
+
+    mask_k = o_k < K
+    mask_v = o_v < V
+    mask_h = mask_k[:, None] & mask_v[None, :]
+
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)
+    if USE_INITIAL_STATE:
+        p_h0 = h0 + i_nh * K * V + o_k[:, None] * V + o_v[None, :]
+        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
+
+    for _ in range(0, T):
+        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
+        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+        b_g = tl.load(p_g).to(tl.float32)
+
+        if USE_QK_L2NORM_IN_KERNEL:
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+        b_q = b_q * scale
+        # [BK, BV]
+        b_h *= exp(b_g)
+        # [BV]
+        b_v -= tl.sum(b_h * b_k[:, None], 0)
+        if IS_BETA_HEADWISE:
+            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
+        else:
+            b_beta = tl.load(p_beta).to(tl.float32)
+        b_v *= b_beta
+        # [BK, BV]
+        b_h += b_k[:, None] * b_v[None, :]
+        # [BV]
+        b_o = tl.sum(b_h * b_q[:, None], 0)
+        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
+
+        p_q += H * K
+        p_k += H * K
+        p_o += HV * V
+        p_v += HV * V
+        p_g += HV
+        p_beta += HV * (V if IS_BETA_HEADWISE else 1)
+
+    if STORE_FINAL_STATE:
+        p_ht = ht + i_nh * K * V + o_k[:, None] * V + o_v[None, :]
+        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+
+
+def fused_recurrent_gated_delta_rule_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    scale: float,
+    initial_state: torch.Tensor,
+    output_final_state: bool,
+    use_qk_l2norm_in_kernel: bool = False,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    B, T, H, K, V = *k.shape, v.shape[-1]
+    HV = v.shape[2]
+    N = B if cu_seqlens is None else len(cu_seqlens) - 1
+    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
+    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1, "NK > 1 is not supported yet"
+    num_stages = 3
+    num_warps = 1
+
+    o = q.new_empty(NK, *v.shape)
+    if output_final_state:
+        final_state = q.new_empty(N, HV, K, V, dtype=torch.float32)
+    else:
+        final_state = None
+
+    grid = (NK, NV, N * HV)
+    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
+        q=q,
+        k=k,
+        v=v,
+        g=g,
+        beta=beta,
+        o=o,
+        h0=initial_state,
+        ht=final_state,
+        cu_seqlens=cu_seqlens,
+        scale=scale,
+        T=T,
+        B=B,
+        H=H,
+        HV=HV,
+        K=K,
+        V=V,
+        BK=BK,
+        BV=BV,
+        IS_BETA_HEADWISE=beta.ndim == v.ndim,
+        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    o = o.squeeze(0)
+    return o, final_state
+
+
+class FusedRecurrentFunction(torch.autograd.Function):
+
+    @staticmethod
+    @input_guard
+    def forward(
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+        scale: float,
+        initial_state: torch.Tensor,
+        output_final_state: bool,
+        cu_seqlens: Optional[torch.LongTensor] = None,
+        use_qk_l2norm_in_kernel: bool = False,
+    ):
+        o, final_state = fused_recurrent_gated_delta_rule_fwd(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            scale=scale,
+            initial_state=initial_state,
+            output_final_state=output_final_state,
+            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+            cu_seqlens=cu_seqlens,
+        )
+
+        return o, final_state
+
+    @staticmethod
+    @input_guard
+    def backward(ctx, do, dht):
+        raise NotImplementedError(
+            "Backward pass is not implemented yet and we do not have plans to implement it "
+            "because we haven't figured out how to compute dg without materializing the full "
+            "hidden states for all time steps."
+        )
+
+
+def fused_recurrent_gated_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor = None,
+    scale: float = None,
+    initial_state: torch.Tensor = None,
+    output_final_state: bool = False,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+    use_qk_l2norm_in_kernel: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""
+    Args:
+        q (torch.Tensor):
+            queries of shape `[B, T, H, K]`.
+        k (torch.Tensor):
+            keys of shape `[B, T, H, K]`.
+        v (torch.Tensor):
+            values of shape `[B, T, HV, V]`.
+            GVA is applied if `HV > H`.
+        g (torch.Tensor):
+            g (decays) of shape `[B, T, HV]`.
+        beta (torch.Tensor):
+            betas of shape `[B, T, HV]`.
+        scale (Optional[float]):
+            Scale factor for the attention scores.
+            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+        initial_state (Optional[torch.Tensor]):
+            Initial state of shape `[N, HV, K, V]` for `N` input sequences.
+            For equal-length input sequences, `N` equals the batch size `B`.
+            Default: `None`.
+        output_final_state (Optional[bool]):
+            Whether to output the final state of shape `[N, HV, K, V]`. Default: `False`.
+        cu_seqlens (torch.LongTensor):
+            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+            consistent with the FlashAttention API.
+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, T, HV, V]`.
+        final_state (torch.Tensor):
+            Final state of shape `[N, HV, K, V]` if `output_final_state=True` else `None`.
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from einops import rearrange
+        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
+        # inputs with equal lengths
+        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
+        >>> q = torch.randn(B, T, H, K, device='cuda')
+        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
+        >>> v = torch.randn(B, T, HV, V, device='cuda')
+        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
+        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
+        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
+        >>> o, ht = fused_recurrent_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True
+        )
+        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
+        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
+            q, k, v, g, beta,
+            initial_state=h0,
+            output_final_state=True,
+            cu_seqlens=cu_seqlens
+        )
+    """
+    if cu_seqlens is not None:
+        if q.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                f"Please flatten variable-length inputs before processing."
+            )
+        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+            raise ValueError(
+                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+            )
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    else:
+        assert scale > 0, "scale must be positive"
+    if beta is None:
+        beta = torch.ones_like(q[..., 0])
+    o, final_state = FusedRecurrentFunction.apply(
+        q,
+        k,
+        v,
+        g,
+        beta,
+        scale,
+        initial_state,
+        output_final_state,
+        cu_seqlens,
+        use_qk_l2norm_in_kernel,
+    )
+    return o, final_state
+
+
+@triton.heuristics(
+    {
+        "USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
+        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+        "CACHE_INTERMEDIATE_STATES": lambda args: args["intermediate_states_buffer"]
+        is not None,
+    }
+)
+@triton.jit(do_not_specialize=["T"])
+def fused_recurrent_gated_delta_rule_update_fwd_kernel(
+    q,
+    k,
+    v,
+    g,
+    beta,
+    o,
+    h0_source,
+    h0_indices,
+    cu_seqlens,
+    scale,
+    intermediate_states_buffer,
+    cache_steps,
+    T,
+    B: tl.constexpr,
+    H: tl.constexpr,
+    HV: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
+    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,
+    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+    DISABLE_STATE_UPDATE: tl.constexpr,  # whether to disable final state update
+    DISABLE_OUTPUT_CALCULATION: tl.constexpr,  # whether to disable output calculation
+    CACHE_INTERMEDIATE_STATES: tl.constexpr,
+):
+    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+    i_n, i_hv = i_nh // HV, i_nh % HV
+    i_h = i_hv // (HV // H)
+    if IS_VARLEN:
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(
+            cu_seqlens + i_n + 1
+        ).to(tl.int64)
+        all = T
+        T = eos - bos
+    else:
+        bos, eos = i_n * T, i_n * T + T
+        all = B * T
+    o_k = i_k * BK + tl.arange(0, BK)
+    o_v = i_v * BV + tl.arange(0, BV)
+
+    p_q = q + (bos * H + i_h) * K + o_k
+    p_k = k + (bos * H + i_h) * K + o_k
+    p_v = v + (bos * HV + i_hv) * V + o_v
+    if IS_BETA_HEADWISE:
+        p_beta = beta + (bos * HV + i_hv) * V + o_v
+    else:
+        p_beta = beta + bos * HV + i_hv
+    p_g = g + bos * HV + i_hv
+    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
+
+    mask_k = o_k < K
+    mask_v = o_v < V
+    mask_h = mask_k[:, None] & mask_v[None, :]
+
+    b_h = tl.zeros([BK, BV], dtype=tl.float32)
+    if USE_INITIAL_STATE:
+        idx = tl.load(h0_indices + i_n)
+        # Add bounds checking for idx
+        if idx >= 0:  # Assuming negative indices are invalid
+            p_h0 = (
+                h0_source
+                + idx * HV * K * V
+                + i_hv * K * V
+                + o_k[:, None] * V
+                + o_v[None, :]
+            )
+            b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
+
+    # Prepare intermediate state cache variables if enabled
+    cache_idx = -1
+    if CACHE_INTERMEDIATE_STATES:
+        cache_idx = tl.load(h0_indices + i_n)
+
+    step_idx = 0
+    for _ in range(0, T):
+        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
+        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+        b_g = tl.load(p_g).to(tl.float32)
+
+        if USE_QK_L2NORM_IN_KERNEL:
+            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
+            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+        b_q = b_q * scale
+        # [BK, BV]
+        b_h *= exp(b_g)
+        # [BV]
+        b_v -= tl.sum(b_h * b_k[:, None], 0)
+        if IS_BETA_HEADWISE:
+            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
+        else:
+            b_beta = tl.load(p_beta).to(tl.float32)
+        b_v *= b_beta
+        # [BK, BV]
+        b_h += b_k[:, None] * b_v[None, :]
+        # [BV]
+        if not DISABLE_OUTPUT_CALCULATION:
+            b_o = tl.sum(b_h * b_q[:, None], 0)
+            # core attn output
+            tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
+
+        # store intermediate states if enabled
+        if CACHE_INTERMEDIATE_STATES:
+            if cache_idx >= 0:
+                # Compute cache pointer for this step
+                step_offset = step_idx * HV * K * V
+                cache_ptr = (
+                    intermediate_states_buffer
+                    + cache_idx * cache_steps * HV * K * V
+                    + step_offset
+                    + i_hv * K * V
+                    + o_k[:, None] * V
+                    + o_v[None, :]
+                )
+                tl.store(cache_ptr, b_h.to(cache_ptr.dtype.element_ty), mask=mask_h)
+
+        step_idx += 1
+
+        p_q += H * K
+        p_k += H * K
+        p_o += HV * V
+        p_v += HV * V
+        p_g += HV
+        p_beta += HV * (V if IS_BETA_HEADWISE else 1)
+
+    # Store final state back to h0_source with bounds checking
+    # ssm states
+    if not DISABLE_STATE_UPDATE:
+        idx = tl.load(h0_indices + i_n)
+        if idx >= 0:  # Add bounds checking
+            p_h0 = (
+                h0_source
+                + idx * HV * K * V
+                + i_hv * K * V
+                + o_k[:, None] * V
+                + o_v[None, :]
+            )
+            tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
+
+
+def fused_recurrent_gated_delta_rule_update_fwd(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    scale: float,
+    initial_state_source: torch.Tensor,
+    initial_state_indices: torch.Tensor,
+    use_qk_l2norm_in_kernel: bool = False,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+    disable_state_update: bool = False,
+    disable_output_calculation: bool = False,
+    intermediate_states_buffer: Optional[torch.Tensor] = None,
+    cache_steps: Optional[int] = None,
+) -> torch.Tensor:
+    B, T, H, K, V = *k.shape, v.shape[-1]
+    HV = v.shape[2]
+    N = B if cu_seqlens is None else len(cu_seqlens) - 1
+    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
+    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+    assert NK == 1, "NK > 1 is not supported yet"
+    num_stages = 3
+    num_warps = 1
+
+    if disable_output_calculation:
+        # When output calculation is disabled, allocate minimal tensor
+        o = q.new_empty(NK, 1, 1, 1, 1)  # minimal allocation
+    else:
+        o = q.new_empty(NK, *v.shape)
+
+    grid = (NK, NV, N * HV)
+
+    fused_recurrent_gated_delta_rule_update_fwd_kernel[grid](
+        q=q,
+        k=k,
+        v=v,
+        g=g,
+        beta=beta,
+        o=o,
+        h0_source=initial_state_source,
+        h0_indices=initial_state_indices,
+        cu_seqlens=cu_seqlens,
+        scale=scale,
+        intermediate_states_buffer=intermediate_states_buffer,
+        cache_steps=0 if cache_steps is None else cache_steps,
+        T=T,
+        B=B,
+        H=H,
+        HV=HV,
+        K=K,
+        V=V,
+        BK=BK,
+        BV=BV,
+        IS_BETA_HEADWISE=beta.ndim == v.ndim,
+        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
+        DISABLE_STATE_UPDATE=disable_state_update,
+        DISABLE_OUTPUT_CALCULATION=disable_output_calculation,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    o = o.squeeze(0)
+    return o
+
+
+class FusedRecurrentUpdateFunction(torch.autograd.Function):
+
+    @staticmethod
+    @input_guard
+    def forward(
+        ctx,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+        scale: float,
+        initial_state_source: torch.Tensor,
+        initial_state_indices: torch.Tensor,
+        cu_seqlens: Optional[torch.LongTensor] = None,
+        use_qk_l2norm_in_kernel: bool = False,
+        disable_state_update: bool = False,
+        disable_output_calculation: bool = False,
+        intermediate_states_buffer: Optional[torch.Tensor] = None,
+        cache_steps: Optional[int] = None,
+    ):
+        o = fused_recurrent_gated_delta_rule_update_fwd(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            scale=scale,
+            initial_state_source=initial_state_source,
+            initial_state_indices=initial_state_indices,
+            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+            cu_seqlens=cu_seqlens,
+            disable_state_update=disable_state_update,
+            disable_output_calculation=disable_output_calculation,
+            intermediate_states_buffer=intermediate_states_buffer,
+            cache_steps=cache_steps,
+        )
+
+        return o
+
+    @staticmethod
+    @input_guard
+    def backward(ctx, do, dht):
+        raise NotImplementedError(
+            "Backward pass is not implemented yet and we do not have plans to implement it "
+            "because we haven't figured out how to compute dg without materializing the full "
+            "hidden states for all time steps."
+        )
+
+
+def fused_recurrent_gated_delta_rule_update(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor = None,
+    scale: float = None,
+    initial_state_source: torch.Tensor = None,
+    initial_state_indices: torch.Tensor = None,
+    cu_seqlens: Optional[torch.LongTensor] = None,
+    use_qk_l2norm_in_kernel: bool = False,
+    disable_state_update: bool = False,
+    disable_output_calculation: bool = False,
+    intermediate_states_buffer: Optional[torch.Tensor] = None,
+    cache_steps: Optional[int] = None,
+) -> torch.Tensor:
+    if cu_seqlens is not None:
+        if q.shape[0] != 1:
+            raise ValueError(
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                f"Please flatten variable-length inputs before processing."
+            )
+        if (
+            initial_state_source is not None
+            and initial_state_indices.shape[0] != len(cu_seqlens) - 1
+        ):
+            raise ValueError(
+                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state_indices.shape[0]}."
+            )
+    if scale is None:
+        scale = k.shape[-1] ** -0.5
+    else:
+        assert scale > 0, "scale must be positive"
+    if beta is None:
+        beta = torch.ones_like(q[..., 0])
+    o = FusedRecurrentUpdateFunction.apply(
+        q,
+        k,
+        v,
+        g,
+        beta,
+        scale,
+        initial_state_source,
+        initial_state_indices,
+        cu_seqlens,
+        use_qk_l2norm_in_kernel,
+        disable_state_update,
+        disable_output_calculation,
+        intermediate_states_buffer,
+        cache_steps,
+    )
+    return o
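
The second public entry point added by this file, fused_recurrent_gated_delta_rule_update, carries no docstring. Below is a minimal, illustrative Python sketch of how it might be driven: the pool size, slot indices, tensor sizes, and the packed two-sequence layout are invented for illustration, and the read-then-overwrite state-pool behavior is inferred from the kernel (h0_source is indexed per sequence via h0_indices and updated in place unless disable_state_update is set). It is not necessarily the exact calling convention used elsewhere in sglang.

# Illustrative sketch only; names and sizes below are hypothetical.
import torch
import torch.nn.functional as F

from sglang.srt.layers.attention.fla.fused_recurrent import (
    fused_recurrent_gated_delta_rule_update,
)

B, T, H, HV, K, V = 1, 64, 4, 8, 128, 128
num_slots = 16  # hypothetical size of a recurrent-state pool

q = torch.randn(B, T, H, K, device="cuda")
k = torch.randn(B, T, H, K, device="cuda")
v = torch.randn(B, T, HV, V, device="cuda")
g = F.logsigmoid(torch.rand(B, T, HV, device="cuda"))
beta = torch.rand(B, T, HV, device="cuda").sigmoid()

# Two variable-length sequences packed along the time axis (B must be 1 when cu_seqlens is given).
cu_seqlens = torch.tensor([0, 24, 64], device="cuda", dtype=torch.long)
state_pool = torch.zeros(num_slots, HV, K, V, device="cuda", dtype=torch.float32)
slot_indices = torch.tensor([3, 7], device="cuda", dtype=torch.int32)  # one pool slot per sequence

o = fused_recurrent_gated_delta_rule_update(
    q, k, v, g, beta,
    initial_state_source=state_pool,    # read per sequence, then overwritten with its final state
    initial_state_indices=slot_indices,
    cu_seqlens=cu_seqlens,
    use_qk_l2norm_in_kernel=True,
)
print(o.shape)  # torch.Size([1, 64, 8, 128]), i.e. [B=1, T, HV, V]

After the call, state_pool[3] and state_pool[7] hold the final recurrent states of the two packed sequences, which is what allows the state pool to be carried across decoding steps without returning a separate final_state tensor.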