sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff reflects the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/fla/kda.py (new file)
@@ -0,0 +1,1359 @@
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/model_executor/layers/fla/ops/kda.py
2
+ # This file contains code copied from the flash-linear-attention project.
3
+ # The original source code was licensed under the MIT license and included
4
+ # the following copyright notice:
5
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import triton
10
+ import triton.language as tl
11
+
12
+ from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
13
+ from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum
14
+ from sglang.srt.layers.attention.fla.fused_recurrent import (
15
+ fused_recurrent_gated_delta_rule_fwd_kernel,
16
+ )
17
+ from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
18
+ from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
19
+ from sglang.srt.layers.attention.fla.op import exp, log
20
+ from sglang.srt.layers.attention.fla.solve_tril import solve_tril
21
+ from sglang.srt.layers.attention.fla.utils import is_amd
22
+
23
+ BT_LIST_AUTOTUNE = [32, 64, 128]
24
+ NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]
25
+
26
+
27
+ def cdiv(a: int, b: int) -> int:
28
+ """Ceiling division."""
29
+ return -(a // -b)
30
+
31
+
32
+ def next_power_of_2(n: int) -> int:
33
+ """The next power of 2 (inclusive)"""
34
+ if n < 1:
35
+ return 1
36
+ return 1 << (n - 1).bit_length()
37
+
38
+
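For orientation, the two rounding helpers behave as follows (illustrative values only, not part of the diff):

    assert cdiv(10, 3) == 4            # -(10 // -3) rounds the quotient up
    assert cdiv(12, 4) == 3            # exact division is unchanged
    assert next_power_of_2(1) == 1
    assert next_power_of_2(5) == 8     # 1 << (5 - 1).bit_length()
    assert next_power_of_2(64) == 64   # powers of two map to themselves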
39
+ def fused_recurrent_kda_fwd(
40
+ q: torch.Tensor,
41
+ k: torch.Tensor,
42
+ v: torch.Tensor,
43
+ g: torch.Tensor,
44
+ beta: torch.Tensor,
45
+ scale: float,
46
+ initial_state: torch.Tensor,
47
+ inplace_final_state: bool = True,
48
+ cu_seqlens: torch.LongTensor | None = None,
49
+ # ssm_state_indices: torch.Tensor | None = None,
50
+ num_accepted_tokens: torch.Tensor | None = None,
51
+ use_qk_l2norm_in_kernel: bool = False,
52
+ ) -> tuple[torch.Tensor, torch.Tensor]:
53
+ B, T, H, K, V = *k.shape, v.shape[-1]
54
+ HV = v.shape[2]
55
+ N = B if cu_seqlens is None else len(cu_seqlens) - 1
56
+ BK, BV = next_power_of_2(K), min(next_power_of_2(V), 8)
57
+ NK, NV = cdiv(K, BK), cdiv(V, BV)
58
+ assert NK == 1, "NK > 1 is not supported yet"
59
+ num_stages = 3
60
+ num_warps = 1
61
+
62
+ o = torch.empty_like(k)
63
+ if inplace_final_state:
64
+ final_state = initial_state
65
+ else:
66
+ final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
67
+
68
+ stride_init_state_token = initial_state.stride(0)
69
+ stride_final_state_token = final_state.stride(0)
70
+
71
+ # if ssm_state_indices is None:
72
+ # stride_indices_seq, stride_indices_tok = 1, 1
73
+ # elif ssm_state_indices.ndim == 1:
74
+ # stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
75
+ # else:
76
+ # stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
77
+
78
+ grid = (NK, NV, N * HV)
79
+ fused_recurrent_gated_delta_rule_fwd_kernel[grid](
80
+ q=q,
81
+ k=k,
82
+ v=v,
83
+ g=g,
84
+ beta=beta,
85
+ o=o,
86
+ h0=initial_state,
87
+ ht=final_state,
88
+ cu_seqlens=cu_seqlens,
89
+ # ssm_state_indices=ssm_state_indices,
90
+ # num_accepted_tokens=num_accepted_tokens,
91
+ scale=scale,
92
+ # N=N,
93
+ T=T,
94
+ B=B,
95
+ H=H,
96
+ HV=HV,
97
+ K=K,
98
+ V=V,
99
+ BK=BK,
100
+ BV=BV,
101
+ # stride_init_state_token=stride_init_state_token,
102
+ # stride_final_state_token=stride_final_state_token,
103
+ # stride_indices_seq=stride_indices_seq,
104
+ # stride_indices_tok=stride_indices_tok,
105
+ IS_BETA_HEADWISE=beta.ndim == v.ndim,
106
+ USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
107
+ # INPLACE_FINAL_STATE=inplace_final_state,
108
+ IS_KDA=True,
109
+ num_warps=num_warps,
110
+ num_stages=num_stages,
111
+ )
112
+
113
+ return o, final_state
114
+
115
+
116
+ def fused_recurrent_kda(
117
+ q: torch.Tensor,
118
+ k: torch.Tensor,
119
+ v: torch.Tensor,
120
+ g: torch.Tensor,
121
+ beta: torch.Tensor = None,
122
+ scale: float = None,
123
+ initial_state: torch.Tensor = None,
124
+ inplace_final_state: bool = True,
125
+ use_qk_l2norm_in_kernel: bool = True,
126
+ cu_seqlens: torch.LongTensor | None = None,
127
+ # ssm_state_indices: torch.LongTensor | None = None,
128
+ **kwargs,
129
+ ) -> tuple[torch.Tensor, torch.Tensor]:
130
+ if cu_seqlens is not None and q.shape[0] != 1:
131
+ raise ValueError(
132
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
133
+ f"Please flatten variable-length inputs before processing."
134
+ )
135
+ if scale is None:
136
+ scale = k.shape[-1] ** -0.5
137
+
138
+ o, final_state = fused_recurrent_kda_fwd(
139
+ q=q.contiguous(),
140
+ k=k.contiguous(),
141
+ v=v.contiguous(),
142
+ g=g.contiguous(),
143
+ beta=beta.contiguous(),
144
+ scale=scale,
145
+ initial_state=initial_state,
146
+ inplace_final_state=inplace_final_state,
147
+ cu_seqlens=cu_seqlens,
148
+ # ssm_state_indices=ssm_state_indices,
149
+ num_accepted_tokens=None,
150
+ use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
151
+ )
152
+ return o, final_state
153
+
154
+
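A minimal decode-style sketch of how `fused_recurrent_kda` might be invoked. The shapes below (batch B, one token per step, H heads, head dims K and V, and a per-request recurrent state) are assumptions inferred from the kernel signature, not a documented API:

    import torch
    from sglang.srt.layers.attention.fla.kda import fused_recurrent_kda

    B, T, H, K, V = 2, 1, 4, 64, 64                       # one new token per sequence (assumed)
    q = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16)
    g = -torch.rand(B, T, H, K, device="cuda", dtype=torch.float32)   # log-space decay gates (assumed)
    beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)
    state = torch.zeros(B, H, K, V, device="cuda", dtype=torch.float32)  # recurrent state, assumed layout

    # With the default inplace_final_state=True, the returned state aliases `state`.
    o, state = fused_recurrent_kda(q, k, v, g, beta=beta, initial_state=state)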
155
+ @triton.heuristics(
156
+ {
157
+ "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
158
+ "HAS_RESIDUAL": lambda args: args["residual"] is not None,
159
+ "HAS_WEIGHT": lambda args: args["w"] is not None,
160
+ "HAS_BIAS": lambda args: args["b"] is not None,
161
+ }
162
+ )
163
+ @triton.jit
164
+ def layer_norm_gated_fwd_kernel(
165
+ x, # pointer to the input
166
+ g, # pointer to the gate
167
+ y, # pointer to the output
168
+ w, # pointer to the weights
169
+ b, # pointer to the biases
170
+ residual, # pointer to the residual
171
+ residual_out, # pointer to the residual
172
+ mean, # pointer to the mean
173
+ rstd, # pointer to the 1/std
174
+ eps, # epsilon to avoid division by zero
175
+ T, # number of rows in x
176
+ D: tl.constexpr, # number of columns in x
177
+ BT: tl.constexpr,
178
+ BD: tl.constexpr,
179
+ ACTIVATION: tl.constexpr,
180
+ IS_RMS_NORM: tl.constexpr,
181
+ STORE_RESIDUAL_OUT: tl.constexpr,
182
+ HAS_RESIDUAL: tl.constexpr,
183
+ HAS_WEIGHT: tl.constexpr,
184
+ HAS_BIAS: tl.constexpr,
185
+ ):
186
+ i_t = tl.program_id(0)
187
+
188
+ o_d = tl.arange(0, BD)
189
+ m_d = o_d < D
190
+
191
+ p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
192
+ b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
193
+ if HAS_RESIDUAL:
194
+ p_res = tl.make_block_ptr(
195
+ residual, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
196
+ )
197
+ b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32)
198
+ if STORE_RESIDUAL_OUT:
199
+ p_res_out = tl.make_block_ptr(
200
+ residual_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
201
+ )
202
+ tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1))
203
+ if not IS_RMS_NORM:
204
+ b_mean = tl.sum(b_x, axis=1) / D
205
+ p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,))
206
+ tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,))
207
+ b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0)
208
+ b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
209
+ else:
210
+ b_xbar = tl.where(m_d[None, :], b_x, 0.0)
211
+ b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
212
+ b_rstd = 1 / tl.sqrt(b_var + eps)
213
+
214
+ p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,))
215
+ tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,))
216
+
217
+ if HAS_WEIGHT:
218
+ b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
219
+ if HAS_BIAS:
220
+ b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
221
+ b_x_hat = (
222
+ (b_x - b_mean[:, None]) * b_rstd[:, None]
223
+ if not IS_RMS_NORM
224
+ else b_x * b_rstd[:, None]
225
+ )
226
+ b_y = b_x_hat * b_w[None, :] if HAS_WEIGHT else b_x_hat
227
+ if HAS_BIAS:
228
+ b_y = b_y + b_b[None, :]
229
+
230
+ # swish/sigmoid output gate
231
+ p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
232
+ b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
233
+ if ACTIVATION == "swish" or ACTIVATION == "silu":
234
+ b_y = b_y * b_g * tl.sigmoid(b_g)
235
+ elif ACTIVATION == "sigmoid":
236
+ b_y = b_y * tl.sigmoid(b_g)
237
+
238
+ # Write output
239
+ p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
240
+ tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
241
+
242
+
243
+ @triton.heuristics(
244
+ {
245
+ "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
246
+ "HAS_RESIDUAL": lambda args: args["residual"] is not None,
247
+ "HAS_WEIGHT": lambda args: args["w"] is not None,
248
+ "HAS_BIAS": lambda args: args["b"] is not None,
249
+ }
250
+ )
251
+ @triton.jit
252
+ def layer_norm_gated_fwd_kernel1(
253
+ x, # pointer to the input
254
+ g, # pointer to the gate
255
+ y, # pointer to the output
256
+ w, # pointer to the weights
257
+ b, # pointer to the biases
258
+ residual, # pointer to the residual
259
+ residual_out, # pointer to the residual
260
+ mean, # pointer to the mean
261
+ rstd, # pointer to the 1/std
262
+ eps, # epsilon to avoid division by zero
263
+ D: tl.constexpr, # number of columns in x
264
+ BD: tl.constexpr,
265
+ ACTIVATION: tl.constexpr,
266
+ IS_RMS_NORM: tl.constexpr,
267
+ STORE_RESIDUAL_OUT: tl.constexpr,
268
+ HAS_RESIDUAL: tl.constexpr,
269
+ HAS_WEIGHT: tl.constexpr,
270
+ HAS_BIAS: tl.constexpr,
271
+ ):
272
+ i_t = tl.program_id(0)
273
+ x += i_t * D
274
+ y += i_t * D
275
+ g += i_t * D
276
+ if HAS_RESIDUAL:
277
+ residual += i_t * D
278
+ if STORE_RESIDUAL_OUT:
279
+ residual_out += i_t * D
280
+
281
+ o_d = tl.arange(0, BD)
282
+ m_d = o_d < D
283
+ b_x = tl.load(x + o_d, mask=m_d, other=0.0).to(tl.float32)
284
+ if HAS_RESIDUAL:
285
+ b_x += tl.load(residual + o_d, mask=m_d, other=0.0).to(tl.float32)
286
+ if STORE_RESIDUAL_OUT:
287
+ tl.store(residual_out + o_d, b_x, mask=m_d)
288
+ if not IS_RMS_NORM:
289
+ b_mean = tl.sum(b_x, axis=0) / D
290
+ tl.store(mean + i_t, b_mean)
291
+ b_xbar = tl.where(m_d, b_x - b_mean, 0.0)
292
+ b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
293
+ else:
294
+ b_xbar = tl.where(m_d, b_x, 0.0)
295
+ b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
296
+ b_rstd = 1 / tl.sqrt(b_var + eps)
297
+ tl.store(rstd + i_t, b_rstd)
298
+
299
+ if HAS_WEIGHT:
300
+ b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
301
+ if HAS_BIAS:
302
+ b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
303
+ b_x_hat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd
304
+ b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat
305
+ if HAS_BIAS:
306
+ b_y = b_y + b_b
307
+
308
+ # swish/sigmoid output gate
309
+ b_g = tl.load(g + o_d, mask=m_d, other=0.0).to(tl.float32)
310
+ if ACTIVATION == "swish" or ACTIVATION == "silu":
311
+ b_y = b_y * b_g * tl.sigmoid(b_g)
312
+ elif ACTIVATION == "sigmoid":
313
+ b_y = b_y * tl.sigmoid(b_g)
314
+
315
+ # Write output
316
+ tl.store(y + o_d, b_y, mask=m_d)
317
+
318
+
319
+ def layer_norm_gated_fwd(
320
+ x: torch.Tensor,
321
+ g: torch.Tensor,
322
+ weight: torch.Tensor,
323
+ bias: torch.Tensor,
324
+ activation: str = "swish",
325
+ eps: float = 1e-5,
326
+ residual: torch.Tensor = None,
327
+ out_dtype: torch.dtype = None,
328
+ residual_dtype: torch.dtype = None,
329
+ is_rms_norm: bool = False,
330
+ ):
331
+ if residual is not None:
332
+ residual_dtype = residual.dtype
333
+ T, D = x.shape
334
+ if residual is not None:
335
+ assert residual.shape == (T, D)
336
+ if weight is not None:
337
+ assert weight.shape == (D,)
338
+ if bias is not None:
339
+ assert bias.shape == (D,)
340
+ # allocate output
341
+ y = x if out_dtype is None else torch.empty_like(x, dtype=out_dtype)
342
+ if residual is not None or (
343
+ residual_dtype is not None and residual_dtype != x.dtype
344
+ ):
345
+ residual_out = torch.empty(T, D, device=x.device, dtype=residual_dtype)
346
+ else:
347
+ residual_out = None
348
+ mean = (
349
+ torch.empty((T,), dtype=torch.float, device=x.device)
350
+ if not is_rms_norm
351
+ else None
352
+ )
353
+ rstd = torch.empty((T,), dtype=torch.float, device=x.device)
354
+ # Less than 64KB per feature: enqueue fused kernel
355
+ MAX_FUSED_SIZE = 65536 // x.element_size()
356
+ BD = min(MAX_FUSED_SIZE, next_power_of_2(D))
357
+ if D > BD:
358
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
359
+ # heuristics for number of warps
360
+
361
+ if D <= 512:
362
+ BT = 32
363
+ layer_norm_gated_fwd_kernel[(cdiv(T, BT),)](
364
+ x=x,
365
+ g=g,
366
+ y=y,
367
+ w=weight,
368
+ b=bias,
369
+ residual=residual,
370
+ residual_out=residual_out,
371
+ mean=mean,
372
+ rstd=rstd,
373
+ eps=eps,
374
+ T=T,
375
+ D=D,
376
+ BD=BD,
377
+ BT=BT,
378
+ ACTIVATION=activation,
379
+ IS_RMS_NORM=is_rms_norm,
380
+ num_warps=4,
381
+ )
382
+ else:
383
+ layer_norm_gated_fwd_kernel1[(T,)](
384
+ x=x,
385
+ g=g,
386
+ y=y,
387
+ w=weight,
388
+ b=bias,
389
+ residual=residual,
390
+ residual_out=residual_out,
391
+ mean=mean,
392
+ rstd=rstd,
393
+ eps=eps,
394
+ D=D,
395
+ BD=BD,
396
+ ACTIVATION=activation,
397
+ IS_RMS_NORM=is_rms_norm,
398
+ num_warps=4,
399
+ )
400
+ # residual_out is None if residual is None and residual_dtype == input_dtype
401
+ return y, mean, rstd, residual_out if residual_out is not None else x
402
+
403
+
404
+ def rms_norm_gated(
405
+ x: torch.Tensor,
406
+ g: torch.Tensor,
407
+ weight: torch.Tensor,
408
+ bias: torch.Tensor,
409
+ activation: str = "swish",
410
+ residual: torch.Tensor | None = None,
411
+ prenorm: bool = False,
412
+ residual_in_fp32: bool = False,
413
+ eps: float = 1e-6,
414
+ ):
415
+ x_shape_og = x.shape
416
+ # reshape input data into 2D tensor
417
+ x = x.contiguous().reshape(-1, x.shape[-1])
418
+ g = g.contiguous().reshape(-1, g.shape[-1])
419
+ if residual is not None:
420
+ assert residual.shape == x_shape_og
421
+ residual = residual.contiguous().reshape(-1, residual.shape[-1])
422
+ residual_dtype = (
423
+ residual.dtype
424
+ if residual is not None
425
+ else (torch.float if residual_in_fp32 else None)
426
+ )
427
+ y, _, _, residual_out = layer_norm_gated_fwd(
428
+ x=x,
429
+ g=g,
430
+ weight=weight,
431
+ bias=bias,
432
+ activation=activation,
433
+ eps=eps,
434
+ residual=residual,
435
+ residual_dtype=residual_dtype,
436
+ is_rms_norm=True,
437
+ )
438
+ y = y.reshape(x_shape_og)
439
+ return y if not prenorm else (y, residual_out.reshape(x_shape_og))
440
+
441
+
442
+ class FusedRMSNormGated(nn.Module):
443
+ def __init__(
444
+ self,
445
+ hidden_size: int,
446
+ elementwise_affine: bool = True,
447
+ eps: float = 1e-5,
448
+ activation: str = "swish",
449
+ device: torch.device | None = None,
450
+ dtype: torch.dtype | None = None,
451
+ ) -> None:
452
+ factory_kwargs = {"device": device, "dtype": dtype}
453
+ super().__init__()
454
+
455
+ self.hidden_size = hidden_size
456
+ self.elementwise_affine = elementwise_affine
457
+ self.eps = eps
458
+ self.activation = activation
459
+
460
+ if self.activation not in ["swish", "silu", "sigmoid"]:
461
+ raise ValueError(f"Unsupported activation: {self.activation}")
462
+
463
+ if elementwise_affine:
464
+ self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
465
+ else:
466
+ self.register_parameter("weight", None)
467
+ self.register_parameter("bias", None)
468
+
469
+ def forward(
470
+ self,
471
+ x: torch.Tensor,
472
+ g: torch.Tensor,
473
+ residual: torch.Tensor | None = None,
474
+ prenorm: bool = False,
475
+ residual_in_fp32: bool = False,
476
+ ) -> torch.Tensor:
477
+ return rms_norm_gated(
478
+ x,
479
+ g,
480
+ self.weight,
481
+ self.bias,
482
+ self.activation,
483
+ residual=residual,
484
+ eps=self.eps,
485
+ prenorm=prenorm,
486
+ residual_in_fp32=residual_in_fp32,
487
+ )
488
+
489
+
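`FusedRMSNormGated` fuses an RMS norm with a multiplicative output gate. A plain-PyTorch reference of the default "swish" path, kept as a sketch for clarity (the residual/prenorm options and the fused Triton kernels above are ignored here):

    import torch

    def rms_norm_gated_reference(x, g, weight, eps=1e-5, activation="swish"):
        # Normalize x over the last dimension in fp32, scale by the learned weight,
        # then gate the result with silu(g) (or sigmoid(g)).
        x_f, g_f = x.float(), g.float()
        rstd = torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
        y = x_f * rstd
        if weight is not None:
            y = y * weight.float()
        if activation in ("swish", "silu"):
            y = y * g_f * torch.sigmoid(g_f)
        elif activation == "sigmoid":
            y = y * torch.sigmoid(g_f)
        return y.to(x.dtype)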
490
+ @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
491
+ @triton.autotune(
492
+ configs=[
493
+ triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
494
+ for BK in [32, 64]
495
+ for num_warps in [1, 2, 4, 8]
496
+ for num_stages in [2, 3, 4]
497
+ ],
498
+ key=["BC"],
499
+ )
500
+ @triton.jit(do_not_specialize=["T"])
501
+ def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
502
+ q,
503
+ k,
504
+ g,
505
+ beta,
506
+ A,
507
+ Aqk,
508
+ scale,
509
+ cu_seqlens,
510
+ chunk_indices,
511
+ T,
512
+ H: tl.constexpr,
513
+ K: tl.constexpr,
514
+ BT: tl.constexpr,
515
+ BC: tl.constexpr,
516
+ BK: tl.constexpr,
517
+ NC: tl.constexpr,
518
+ IS_VARLEN: tl.constexpr,
519
+ ):
520
+ i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
521
+ i_b, i_h = i_bh // H, i_bh % H
522
+ i_i, i_j = i_c // NC, i_c % NC
523
+ if IS_VARLEN:
524
+ i_n, i_t = (
525
+ tl.load(chunk_indices + i_t * 2).to(tl.int32),
526
+ tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
527
+ )
528
+ bos, eos = (
529
+ tl.load(cu_seqlens + i_n).to(tl.int32),
530
+ tl.load(cu_seqlens + i_n + 1).to(tl.int32),
531
+ )
532
+ T = eos - bos
533
+ else:
534
+ bos, eos = i_b * T, i_b * T + T
535
+
536
+ if i_t * BT + i_i * BC >= T:
537
+ return
538
+ if i_i <= i_j:
539
+ return
540
+
541
+ q += (bos * H + i_h) * K
542
+ k += (bos * H + i_h) * K
543
+ g += (bos * H + i_h) * K
544
+ A += (bos * H + i_h) * BT
545
+ Aqk += (bos * H + i_h) * BT
546
+
547
+ p_b = tl.make_block_ptr(
548
+ beta + bos * H + i_h, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,)
549
+ )
550
+ b_b = tl.load(p_b, boundary_check=(0,))
551
+
552
+ b_A = tl.zeros([BC, BC], dtype=tl.float32)
553
+ b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
554
+ for i_k in range(tl.cdiv(K, BK)):
555
+ p_q = tl.make_block_ptr(
556
+ q, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
557
+ )
558
+ p_k = tl.make_block_ptr(
559
+ k, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
560
+ )
561
+ p_g = tl.make_block_ptr(
562
+ g, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
563
+ )
564
+ b_kt = tl.make_block_ptr(
565
+ k, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
566
+ )
567
+ p_gk = tl.make_block_ptr(
568
+ g, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
569
+ )
570
+
571
+ o_k = i_k * BK + tl.arange(0, BK)
572
+ m_k = o_k < K
573
+ # [BK,]
574
+ b_gn = tl.load(g + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
575
+ # [BC, BK]
576
+ b_g = tl.load(p_g, boundary_check=(0, 1))
577
+ b_k = tl.load(p_k, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :])
578
+ # [BK, BC]
579
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
580
+ b_kt = tl.load(b_kt, boundary_check=(0, 1))
581
+ # [BC, BC]
582
+ b_ktg = b_kt * exp(b_gn[:, None] - b_gk)
583
+ b_A += tl.dot(b_k, b_ktg)
584
+
585
+ b_q = tl.load(p_q, boundary_check=(0, 1))
586
+ b_qg = b_q * exp(b_g - b_gn[None, :]) * scale
587
+ b_Aqk += tl.dot(b_qg, b_ktg)
588
+
589
+ b_A *= b_b[:, None]
590
+
591
+ p_A = tl.make_block_ptr(
592
+ A, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
593
+ )
594
+ tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
595
+ p_Aqk = tl.make_block_ptr(
596
+ Aqk, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
597
+ )
598
+ tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1))
599
+
600
+
601
+ @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
602
+ @triton.autotune(
603
+ configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
604
+ key=["BK", "BT"],
605
+ )
606
+ @triton.jit(do_not_specialize=["T"])
607
+ def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
608
+ q,
609
+ k,
610
+ g,
611
+ beta,
612
+ A,
613
+ Aqk,
614
+ scale,
615
+ cu_seqlens,
616
+ chunk_indices,
617
+ T,
618
+ H: tl.constexpr,
619
+ K: tl.constexpr,
620
+ BT: tl.constexpr,
621
+ BC: tl.constexpr,
622
+ BK: tl.constexpr,
623
+ IS_VARLEN: tl.constexpr,
624
+ ):
625
+ i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
626
+ i_b, i_h = i_bh // H, i_bh % H
627
+ if IS_VARLEN:
628
+ i_n, i_t = (
629
+ tl.load(chunk_indices + i_t * 2).to(tl.int32),
630
+ tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
631
+ )
632
+ bos, eos = (
633
+ tl.load(cu_seqlens + i_n).to(tl.int32),
634
+ tl.load(cu_seqlens + i_n + 1).to(tl.int32),
635
+ )
636
+ T = eos - bos
637
+ else:
638
+ bos, eos = i_b * T, i_b * T + T
639
+
640
+ if i_t * BT + i_i * BC >= T:
641
+ return
642
+
643
+ o_i = tl.arange(0, BC)
644
+ o_k = tl.arange(0, BK)
645
+ m_k = o_k < K
646
+ m_A = (i_t * BT + i_i * BC + o_i) < T
647
+ o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC
648
+
649
+ p_q = tl.make_block_ptr(
650
+ q + (bos * H + i_h) * K,
651
+ (T, K),
652
+ (H * K, 1),
653
+ (i_t * BT + i_i * BC, 0),
654
+ (BC, BK),
655
+ (1, 0),
656
+ )
657
+ p_k = tl.make_block_ptr(
658
+ k + (bos * H + i_h) * K,
659
+ (T, K),
660
+ (H * K, 1),
661
+ (i_t * BT + i_i * BC, 0),
662
+ (BC, BK),
663
+ (1, 0),
664
+ )
665
+ p_g = tl.make_block_ptr(
666
+ g + (bos * H + i_h) * K,
667
+ (T, K),
668
+ (H * K, 1),
669
+ (i_t * BT + i_i * BC, 0),
670
+ (BC, BK),
671
+ (1, 0),
672
+ )
673
+ b_q = tl.load(p_q, boundary_check=(0, 1))
674
+ b_k = tl.load(p_k, boundary_check=(0, 1))
675
+ b_g = tl.load(p_g, boundary_check=(0, 1))
676
+
677
+ p_b = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h
678
+ b_k = b_k * tl.load(p_b, mask=m_A, other=0)[:, None]
679
+
680
+ p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
681
+ p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
682
+
683
+ for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
684
+ b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
685
+ b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
686
+ b_ktg = b_kt[None, :] * exp(b_g - b_gk[None, :])
687
+ b_A = tl.sum(b_k * b_ktg, 1)
688
+ b_A = tl.where(o_i > j, b_A, 0.0)
689
+ b_Aqk = tl.sum(b_q * b_ktg, 1)
690
+ b_Aqk = tl.where(o_i >= j, b_Aqk * scale, 0.0)
691
+ tl.store(A + o_A + j, b_A, mask=m_A)
692
+ tl.store(Aqk + o_A + j, b_Aqk, mask=m_A)
693
+ p_kt += H * K
694
+ p_gk += H * K
695
+
696
+
697
+ def chunk_kda_scaled_dot_kkt_fwd(
698
+ q: torch.Tensor,
699
+ k: torch.Tensor,
700
+ gk: torch.Tensor | None = None,
701
+ beta: torch.Tensor | None = None,
702
+ scale: float | None = None,
703
+ cu_seqlens: torch.LongTensor | None = None,
704
+ chunk_size: int = 64,
705
+ output_dtype: torch.dtype = torch.float32,
706
+ ) -> tuple[torch.Tensor, torch.Tensor]:
707
+ r"""
708
+ Compute beta * K * K^T and the matching scaled Q * K^T within each chunk.
709
+
710
+ Args:
711
+ k (torch.Tensor):
712
+ The key tensor of shape `[B, T, H, K]`.
713
+ beta (torch.Tensor):
714
+ The beta tensor of shape `[B, T, H]`.
715
+ gk (torch.Tensor):
716
+ The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
717
+ cu_seqlens (torch.LongTensor):
718
+ The cumulative sequence lengths of the input tensor.
719
+ Default: None
720
+ chunk_size (int):
721
+ The chunk size. Default: 64.
722
+ output_dtype (torch.dtype):
723
+ The dtype of the output tensor. Default: `torch.float32`
724
+
725
+ Returns:
726
+ A tuple `(A, Aqk)`, each of shape `[B, T, H, BT]` where `BT` is the chunk size: `A` holds `beta * K * K^T` (strictly lower-triangular within each chunk) and `Aqk` holds the scaled `Q * K^T` entries (lower-triangular including the diagonal).
727
+ """
728
+ B, T, H, K = k.shape
729
+ assert K <= 256
730
+ BT = chunk_size
731
+ chunk_indices = (
732
+ prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
733
+ )
734
+ NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
735
+
736
+ BC = min(16, BT)
737
+ NC = cdiv(BT, BC)
738
+ BK = max(next_power_of_2(K), 16)
739
+ A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
740
+ Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
741
+ grid = (NT, NC * NC, B * H)
742
+ chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
743
+ q=q,
744
+ k=k,
745
+ g=gk,
746
+ beta=beta,
747
+ A=A,
748
+ Aqk=Aqk,
749
+ scale=scale,
750
+ cu_seqlens=cu_seqlens,
751
+ chunk_indices=chunk_indices,
752
+ T=T,
753
+ H=H,
754
+ K=K,
755
+ BT=BT,
756
+ BC=BC,
757
+ NC=NC,
758
+ )
759
+
760
+ grid = (NT, NC, B * H)
761
+ chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
762
+ q=q,
763
+ k=k,
764
+ g=gk,
765
+ beta=beta,
766
+ A=A,
767
+ Aqk=Aqk,
768
+ scale=scale,
769
+ cu_seqlens=cu_seqlens,
770
+ chunk_indices=chunk_indices,
771
+ T=T,
772
+ H=H,
773
+ K=K,
774
+ BT=BT,
775
+ BC=BC,
776
+ BK=BK,
777
+ )
778
+ return A, Aqk
779
+
780
+
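Within one chunk, the two kernels above produce `A[i, j] = beta[i] * sum_k k[i] * exp(g[i] - g[j]) * k[j]` for `i > j` and `Aqk[i, j] = scale * sum_k q[i] * exp(g[i] - g[j]) * k[j]` for `i >= j`. A slow single-head, single-chunk PyTorch reference, as an illustrative sketch under that assumed `[T, K]` layout:

    import torch

    def kkt_reference(q, k, g, beta, scale):
        # q, k, g: [T, K] for one head and one chunk (T <= chunk size);
        # g is the per-chunk cumulative log gate, beta: [T].
        decay = torch.exp(g[:, None, :] - g[None, :, :])       # [T, T, K] = exp(g_i - g_j)
        dots_kk = torch.einsum("ik,ijk,jk->ij", k, decay, k)   # k_i . (decay * k_j)
        dots_qk = torch.einsum("ik,ijk,jk->ij", q, decay, k)
        A = torch.tril(dots_kk, diagonal=-1) * beta[:, None]   # strictly lower triangular
        Aqk = torch.tril(dots_qk) * scale                      # lower triangular incl. diagonal
        return A, Aqk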
781
+ @triton.heuristics(
782
+ {
783
+ "STORE_QG": lambda args: args["qg"] is not None,
784
+ "STORE_KG": lambda args: args["kg"] is not None,
785
+ "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
786
+ }
787
+ )
788
+ @triton.autotune(
789
+ configs=[
790
+ triton.Config({}, num_warps=num_warps, num_stages=num_stages)
791
+ for num_warps in [2, 4, 8]
792
+ for num_stages in [2, 3, 4]
793
+ ],
794
+ key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
795
+ )
796
+ @triton.jit(do_not_specialize=["T"])
797
+ def recompute_w_u_fwd_kernel(
798
+ q,
799
+ k,
800
+ qg,
801
+ kg,
802
+ v,
803
+ beta,
804
+ w,
805
+ u,
806
+ A,
807
+ gk,
808
+ cu_seqlens,
809
+ chunk_indices,
810
+ T,
811
+ H: tl.constexpr,
812
+ K: tl.constexpr,
813
+ V: tl.constexpr,
814
+ BT: tl.constexpr,
815
+ BK: tl.constexpr,
816
+ BV: tl.constexpr,
817
+ STORE_QG: tl.constexpr,
818
+ STORE_KG: tl.constexpr,
819
+ IS_VARLEN: tl.constexpr,
820
+ DOT_PRECISION: tl.constexpr,
821
+ ):
822
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
823
+ i_b, i_h = i_bh // H, i_bh % H
824
+ if IS_VARLEN:
825
+ i_n, i_t = (
826
+ tl.load(chunk_indices + i_t * 2).to(tl.int32),
827
+ tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
828
+ )
829
+ bos, eos = (
830
+ tl.load(cu_seqlens + i_n).to(tl.int32),
831
+ tl.load(cu_seqlens + i_n + 1).to(tl.int32),
832
+ )
833
+ T = eos - bos
834
+ else:
835
+ bos, eos = i_b * T, i_b * T + T
836
+ p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
837
+ b_b = tl.load(p_b, boundary_check=(0,))
838
+
839
+ p_A = tl.make_block_ptr(
840
+ A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
841
+ )
842
+ b_A = tl.load(p_A, boundary_check=(0, 1))
843
+
844
+ for i_v in range(tl.cdiv(V, BV)):
845
+ p_v = tl.make_block_ptr(
846
+ v + (bos * H + i_h) * V,
847
+ (T, V),
848
+ (H * V, 1),
849
+ (i_t * BT, i_v * BV),
850
+ (BT, BV),
851
+ (1, 0),
852
+ )
853
+ p_u = tl.make_block_ptr(
854
+ u + (bos * H + i_h) * V,
855
+ (T, V),
856
+ (H * V, 1),
857
+ (i_t * BT, i_v * BV),
858
+ (BT, BV),
859
+ (1, 0),
860
+ )
861
+ b_v = tl.load(p_v, boundary_check=(0, 1))
862
+ b_vb = (b_v * b_b[:, None]).to(b_v.dtype)
863
+ b_u = tl.dot(b_A, b_vb, input_precision=DOT_PRECISION)
864
+ tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
865
+
866
+ for i_k in range(tl.cdiv(K, BK)):
867
+ p_w = tl.make_block_ptr(
868
+ w + (bos * H + i_h) * K,
869
+ (T, K),
870
+ (H * K, 1),
871
+ (i_t * BT, i_k * BK),
872
+ (BT, BK),
873
+ (1, 0),
874
+ )
875
+ p_k = tl.make_block_ptr(
876
+ k + (bos * H + i_h) * K,
877
+ (T, K),
878
+ (H * K, 1),
879
+ (i_t * BT, i_k * BK),
880
+ (BT, BK),
881
+ (1, 0),
882
+ )
883
+ b_k = tl.load(p_k, boundary_check=(0, 1))
884
+ b_kb = b_k * b_b[:, None]
885
+
886
+ p_gk = tl.make_block_ptr(
887
+ gk + (bos * H + i_h) * K,
888
+ (T, K),
889
+ (H * K, 1),
890
+ (i_t * BT, i_k * BK),
891
+ (BT, BK),
892
+ (1, 0),
893
+ )
894
+ b_gk = tl.load(p_gk, boundary_check=(0, 1))
895
+ b_kb *= exp(b_gk)
896
+ if STORE_QG:
897
+ p_q = tl.make_block_ptr(
898
+ q + (bos * H + i_h) * K,
899
+ (T, K),
900
+ (H * K, 1),
901
+ (i_t * BT, i_k * BK),
902
+ (BT, BK),
903
+ (1, 0),
904
+ )
905
+ p_qg = tl.make_block_ptr(
906
+ qg + (bos * H + i_h) * K,
907
+ (T, K),
908
+ (H * K, 1),
909
+ (i_t * BT, i_k * BK),
910
+ (BT, BK),
911
+ (1, 0),
912
+ )
913
+ b_q = tl.load(p_q, boundary_check=(0, 1))
914
+ b_qg = b_q * exp(b_gk)
915
+ tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1))
916
+ if STORE_KG:
917
+ last_idx = min(i_t * BT + BT, T) - 1
918
+
919
+ o_k = i_k * BK + tl.arange(0, BK)
920
+ m_k = o_k < K
921
+ b_gn = tl.load(
922
+ gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=m_k, other=0.0
923
+ )
924
+ b_kg = b_k * exp(b_gn - b_gk)
925
+
926
+ p_kg = tl.make_block_ptr(
927
+ kg + (bos * H + i_h) * K,
928
+ (T, K),
929
+ (H * K, 1),
930
+ (i_t * BT, i_k * BK),
931
+ (BT, BK),
932
+ (1, 0),
933
+ )
934
+ tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1))
935
+
936
+ b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
937
+ tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
938
+
939
+
940
+ def recompute_w_u_fwd(
941
+ k: torch.Tensor,
942
+ v: torch.Tensor,
943
+ beta: torch.Tensor,
944
+ A: torch.Tensor,
945
+ q: torch.Tensor | None = None,
946
+ gk: torch.Tensor | None = None,
947
+ cu_seqlens: torch.LongTensor | None = None,
948
+ ) -> tuple[torch.Tensor, torch.Tensor]:
949
+ B, T, H, K, V = *k.shape, v.shape[-1]
950
+ BT = A.shape[-1]
951
+ BK = 64
952
+ BV = 64
953
+
954
+ chunk_indices = (
955
+ prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
956
+ )
957
+ NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
958
+
959
+ w = torch.empty_like(k)
960
+ u = torch.empty_like(v)
961
+ kg = torch.empty_like(k) if gk is not None else None
962
+ recompute_w_u_fwd_kernel[(NT, B * H)](
963
+ q=q,
964
+ k=k,
965
+ qg=None,
966
+ kg=kg,
967
+ v=v,
968
+ beta=beta,
969
+ w=w,
970
+ u=u,
971
+ A=A,
972
+ gk=gk,
973
+ cu_seqlens=cu_seqlens,
974
+ chunk_indices=chunk_indices,
975
+ T=T,
976
+ H=H,
977
+ K=K,
978
+ V=V,
979
+ BT=BT,
980
+ BK=BK,
981
+ BV=BV,
982
+ DOT_PRECISION="ieee",
983
+ )
984
+ return w, u, None, kg
985
+
986
+
987
+ @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
988
+ @triton.autotune(
989
+ configs=[
990
+ triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
991
+ for BK in [32, 64]
992
+ for BV in [64, 128]
993
+ for num_warps in [2, 4, 8]
994
+ for num_stages in [2, 3, 4]
995
+ ],
996
+ key=["BT"],
997
+ )
998
+ @triton.jit(do_not_specialize=["T"])
999
+ def chunk_gla_fwd_kernel_o(
1000
+ q,
1001
+ v,
1002
+ g,
1003
+ h,
1004
+ o,
1005
+ A,
1006
+ cu_seqlens,
1007
+ chunk_indices,
1008
+ scale,
1009
+ T,
1010
+ H: tl.constexpr,
1011
+ K: tl.constexpr,
1012
+ V: tl.constexpr,
1013
+ BT: tl.constexpr,
1014
+ BK: tl.constexpr,
1015
+ BV: tl.constexpr,
1016
+ IS_VARLEN: tl.constexpr,
1017
+ ):
1018
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
1019
+ i_b, i_h = i_bh // H, i_bh % H
1020
+ if IS_VARLEN:
1021
+ i_tg = i_t
1022
+ i_n, i_t = (
1023
+ tl.load(chunk_indices + i_t * 2).to(tl.int32),
1024
+ tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
1025
+ )
1026
+ bos, eos = (
1027
+ tl.load(cu_seqlens + i_n).to(tl.int32),
1028
+ tl.load(cu_seqlens + i_n + 1).to(tl.int32),
1029
+ )
1030
+ T = eos - bos
1031
+ NT = tl.cdiv(T, BT)
1032
+ else:
1033
+ NT = tl.cdiv(T, BT)
1034
+ i_tg = i_b * NT + i_t
1035
+ bos, eos = i_b * T, i_b * T + T
1036
+
1037
+ m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
1038
+
1039
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
1040
+ for i_k in range(tl.cdiv(K, BK)):
1041
+ p_q = tl.make_block_ptr(
1042
+ q + (bos * H + i_h) * K,
1043
+ (T, K),
1044
+ (H * K, 1),
1045
+ (i_t * BT, i_k * BK),
1046
+ (BT, BK),
1047
+ (1, 0),
1048
+ )
1049
+ p_g = tl.make_block_ptr(
1050
+ g + (bos * H + i_h) * K,
1051
+ (T, K),
1052
+ (H * K, 1),
1053
+ (i_t * BT, i_k * BK),
1054
+ (BT, BK),
1055
+ (1, 0),
1056
+ )
1057
+ p_h = tl.make_block_ptr(
1058
+ h + (i_tg * H + i_h) * K * V,
1059
+ (K, V),
1060
+ (V, 1),
1061
+ (i_k * BK, i_v * BV),
1062
+ (BK, BV),
1063
+ (1, 0),
1064
+ )
1065
+
1066
+ # [BT, BK]
1067
+ b_q = tl.load(p_q, boundary_check=(0, 1))
1068
+ b_q = (b_q * scale).to(b_q.dtype)
1069
+ # [BT, BK]
1070
+ b_g = tl.load(p_g, boundary_check=(0, 1))
1071
+ # [BT, BK]
1072
+ b_qg = (b_q * exp(b_g)).to(b_q.dtype)
1073
+ # [BK, BV]
1074
+ b_h = tl.load(p_h, boundary_check=(0, 1))
1075
+ # works but dkw, owing to divine benevolence
1076
+ # [BT, BV]
1077
+ if i_k >= 0:
1078
+ b_o += tl.dot(b_qg, b_h.to(b_qg.dtype))
1079
+ p_v = tl.make_block_ptr(
1080
+ v + (bos * H + i_h) * V,
1081
+ (T, V),
1082
+ (H * V, 1),
1083
+ (i_t * BT, i_v * BV),
1084
+ (BT, BV),
1085
+ (1, 0),
1086
+ )
1087
+ p_o = tl.make_block_ptr(
1088
+ o + (bos * H + i_h) * V,
1089
+ (T, V),
1090
+ (H * V, 1),
1091
+ (i_t * BT, i_v * BV),
1092
+ (BT, BV),
1093
+ (1, 0),
1094
+ )
1095
+ p_A = tl.make_block_ptr(
1096
+ A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
1097
+ )
1098
+ # [BT, BV]
1099
+ b_v = tl.load(p_v, boundary_check=(0, 1))
1100
+ # [BT, BT]
1101
+ b_A = tl.load(p_A, boundary_check=(0, 1))
1102
+ b_A = tl.where(m_s, b_A, 0.0).to(b_v.dtype)
1103
+ b_o += tl.dot(b_A, b_v, allow_tf32=False)
1104
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
1105
+
1106
+
1107
+ def chunk_gla_fwd_o_gk(
1108
+ q: torch.Tensor,
1109
+ v: torch.Tensor,
1110
+ g: torch.Tensor,
1111
+ A: torch.Tensor,
1112
+ h: torch.Tensor,
1113
+ o: torch.Tensor,
1114
+ scale: float,
1115
+ cu_seqlens: torch.LongTensor | None = None,
1116
+ chunk_size: int = 64,
1117
+ ):
1118
+ B, T, H, K, V = *q.shape, v.shape[-1]
1119
+ BT = chunk_size
1120
+
1121
+ chunk_indices = (
1122
+ prepare_chunk_indices(cu_seqlens, chunk_size)
1123
+ if cu_seqlens is not None
1124
+ else None
1125
+ )
1126
+ NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
1127
+
1128
+ def grid(meta):
1129
+ return (cdiv(V, meta["BV"]), NT, B * H)
1130
+
1131
+ chunk_gla_fwd_kernel_o[grid](
1132
+ q=q,
1133
+ v=v,
1134
+ g=g,
1135
+ h=h,
1136
+ o=o,
1137
+ A=A,
1138
+ cu_seqlens=cu_seqlens,
1139
+ chunk_indices=chunk_indices,
1140
+ scale=scale,
1141
+ T=T,
1142
+ H=H,
1143
+ K=K,
1144
+ V=V,
1145
+ BT=BT,
1146
+ )
1147
+ return o
1148
+
1149
+
1150
+ def chunk_kda_fwd(
1151
+ q: torch.Tensor,
1152
+ k: torch.Tensor,
1153
+ v: torch.Tensor,
1154
+ g: torch.Tensor,
1155
+ beta: torch.Tensor,
1156
+ scale: float,
1157
+ initial_state: torch.Tensor,
1158
+ output_final_state: bool,
1159
+ cu_seqlens: torch.LongTensor | None = None,
1160
+ ):
1161
+ chunk_size = 64
1162
+ g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
1163
+ # the intra Aqk is kept in fp32
1164
+ # the computation has very marginal effect on the entire throughput
1165
+ A, Aqk = chunk_kda_scaled_dot_kkt_fwd(
1166
+ q=q,
1167
+ k=k,
1168
+ gk=g,
1169
+ beta=beta,
1170
+ scale=scale,
1171
+ cu_seqlens=cu_seqlens,
1172
+ output_dtype=torch.float32,
1173
+ )
1174
+ A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
1175
+ w, u, _, kg = recompute_w_u_fwd(
1176
+ k=k,
1177
+ v=v,
1178
+ beta=beta,
1179
+ A=A,
1180
+ gk=g,
1181
+ cu_seqlens=cu_seqlens,
1182
+ )
1183
+ del A
1184
+ h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
1185
+ k=kg,
1186
+ w=w,
1187
+ u=u,
1188
+ gk=g,
1189
+ initial_state=initial_state,
1190
+ output_final_state=output_final_state,
1191
+ cu_seqlens=cu_seqlens,
1192
+ )
1193
+ del w, u, kg
1194
+ o = chunk_gla_fwd_o_gk(
1195
+ q=q,
1196
+ v=v_new,
1197
+ g=g,
1198
+ A=Aqk,
1199
+ h=h,
1200
+ o=v,
1201
+ scale=scale,
1202
+ cu_seqlens=cu_seqlens,
1203
+ chunk_size=chunk_size,
1204
+ )
1205
+ del Aqk, v_new, h
1206
+ return o, final_state
1207
+
1208
+
1209
+ def chunk_kda(
1210
+ q: torch.Tensor,
1211
+ k: torch.Tensor,
1212
+ v: torch.Tensor,
1213
+ g: torch.Tensor,
1214
+ beta: torch.Tensor,
1215
+ scale: float = None,
1216
+ initial_state: torch.Tensor = None,
1217
+ output_final_state: bool = False,
1218
+ use_qk_l2norm_in_kernel: bool = False,
1219
+ cu_seqlens: torch.LongTensor | None = None,
1220
+ **kwargs,
1221
+ ):
1222
+ if scale is None:
1223
+ scale = k.shape[-1] ** -0.5
1224
+
1225
+ if use_qk_l2norm_in_kernel:
1226
+ q = l2norm_fwd(q.contiguous())
1227
+ k = l2norm_fwd(k.contiguous())
1228
+
1229
+ o, final_state = chunk_kda_fwd(
1230
+ q=q,
1231
+ k=k,
1232
+ v=v.contiguous(),
1233
+ g=g.contiguous(),
1234
+ beta=beta.contiguous(),
1235
+ scale=scale,
1236
+ initial_state=initial_state.contiguous(),
1237
+ output_final_state=output_final_state,
1238
+ cu_seqlens=cu_seqlens,
1239
+ )
1240
+ return o, final_state
1241
+
1242
+
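A prefill-style sketch of calling `chunk_kda`. The `[B, T, H, K]`/`[B, T, H, V]` layouts, the negative log-space gates, and the `[B, H, K, V]` state are assumptions inferred from the chunked kernels above, not a documented API (note that `initial_state` must be provided, since it is made contiguous unconditionally):

    import torch
    from sglang.srt.layers.attention.fla.kda import chunk_kda

    B, T, H, K, V = 1, 512, 4, 64, 64
    q = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16)
    g = -torch.rand(B, T, H, K, device="cuda", dtype=torch.float32)   # per-dim log decay (assumed)
    beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)
    state = torch.zeros(B, H, K, V, device="cuda", dtype=torch.float32)

    o, final_state = chunk_kda(
        q, k, v, g, beta,
        initial_state=state,
        output_final_state=True,
        use_qk_l2norm_in_kernel=True,
    )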
1243
+ @triton.autotune(
1244
+ configs=[
1245
+ triton.Config({"BT": bt}, num_warps=nw, num_stages=ns)
1246
+ for bt in BT_LIST_AUTOTUNE
1247
+ for nw in NUM_WARPS_AUTOTUNE
1248
+ for ns in [2, 3]
1249
+ ],
1250
+ key=["H", "D"],
1251
+ )
1252
+ @triton.jit
1253
+ def kda_gate_fwd_kernel(
1254
+ g,
1255
+ A,
1256
+ y,
1257
+ g_bias,
1258
+ beta: tl.constexpr,
1259
+ threshold: tl.constexpr,
1260
+ T,
1261
+ H,
1262
+ D: tl.constexpr,
1263
+ BT: tl.constexpr,
1264
+ BD: tl.constexpr,
1265
+ HAS_BIAS: tl.constexpr,
1266
+ ):
1267
+ i_t, i_h = tl.program_id(0), tl.program_id(1)
1268
+ n_t = i_t * BT
1269
+
1270
+ b_a = tl.load(A + i_h).to(tl.float32)
1271
+ b_a = -tl.exp(b_a)
1272
+
1273
+ stride_row = H * D
1274
+ stride_col = 1
1275
+
1276
+ g_ptr = tl.make_block_ptr(
1277
+ base=g + i_h * D,
1278
+ shape=(T, D),
1279
+ strides=(stride_row, stride_col),
1280
+ offsets=(n_t, 0),
1281
+ block_shape=(BT, BD),
1282
+ order=(1, 0),
1283
+ )
1284
+
1285
+ y_ptr = tl.make_block_ptr(
1286
+ base=y + i_h * D,
1287
+ shape=(T, D),
1288
+ strides=(stride_row, stride_col),
1289
+ offsets=(n_t, 0),
1290
+ block_shape=(BT, BD),
1291
+ order=(1, 0),
1292
+ )
1293
+
1294
+ b_g = tl.load(g_ptr, boundary_check=(0, 1)).to(tl.float32)
1295
+
1296
+ if HAS_BIAS:
1297
+ n_d = tl.arange(0, BD)
1298
+ bias_mask = n_d < D
1299
+ b_bias = tl.load(g_bias + i_h * D + n_d, mask=bias_mask, other=0.0).to(
1300
+ tl.float32
1301
+ )
1302
+ b_g = b_g + b_bias[None, :]
1303
+
1304
+ # softplus(x, beta) = (1/beta) * log(1 + exp(beta * x))
1305
+ # When beta * x > threshold, use linear approximation x
1306
+ # Use threshold to switch to linear when beta*x > threshold
1307
+ g_scaled = b_g * beta
1308
+ use_linear = g_scaled > threshold
1309
+ sp = tl.where(use_linear, b_g, (1.0 / beta) * log(1.0 + tl.exp(g_scaled)))
1310
+ b_y = b_a * sp
1311
+
1312
+ tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1))
1313
+
1314
+
1315
+ def fused_kda_gate(
1316
+ g: torch.Tensor,
1317
+ A: torch.Tensor,
1318
+ head_k_dim: int,
1319
+ g_bias: torch.Tensor | None = None,
1320
+ beta: float = 1.0,
1321
+ threshold: float = 20.0,
1322
+ ) -> torch.Tensor:
1323
+ """
1324
+ Forward pass for KDA gate:
1325
+ input g: [..., H*D]
1326
+ param A: [H] or [1, 1, H, 1]
1327
+ beta: softplus beta parameter
1328
+ threshold: softplus threshold parameter
1329
+ return : [..., H, D]
1330
+ """
1331
+ orig_shape = g.shape[:-1]
1332
+
1333
+ g = g.view(-1, g.shape[-1])
1334
+ T = g.shape[0]
1335
+ HD = g.shape[1]
1336
+ H = A.numel()
1337
+ assert H * head_k_dim == HD
1338
+
1339
+ y = torch.empty_like(g, dtype=torch.float32)
1340
+
1341
+ def grid(meta):
1342
+ return (cdiv(T, meta["BT"]), H)
1343
+
1344
+ kda_gate_fwd_kernel[grid](
1345
+ g,
1346
+ A,
1347
+ y,
1348
+ g_bias,
1349
+ beta,
1350
+ threshold,
1351
+ T,
1352
+ H,
1353
+ head_k_dim,
1354
+ BD=next_power_of_2(head_k_dim),
1355
+ HAS_BIAS=g_bias is not None,
1356
+ )
1357
+
1358
+ y = y.view(*orig_shape, H, head_k_dim)
1359
+ return y
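For reference, `fused_kda_gate` amounts to the following plain-PyTorch computation (a sketch, not the fused path): reshape the flat gate to `[..., H, D]`, add the optional per-head bias, apply a thresholded softplus, and scale each head by `-exp(A[h])`:

    import torch
    import torch.nn.functional as F

    def fused_kda_gate_reference(g, A, head_k_dim, g_bias=None, beta=1.0, threshold=20.0):
        H = A.numel()
        g = g.float().reshape(*g.shape[:-1], H, head_k_dim)    # [..., H, D]
        if g_bias is not None:
            g = g + g_bias.float().reshape(H, head_k_dim)
        # softplus(x) = (1/beta) * log(1 + exp(beta * x)), switching to x when beta * x > threshold
        sp = F.softplus(g, beta=beta, threshold=threshold)
        return -torch.exp(A.float()).reshape(H, 1) * sp        # per-head scale of -exp(A[h])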