sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -189,6 +189,7 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
189
189
  from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
190
190
  GDNAttnBackend,
191
191
  HybridLinearAttnBackend,
192
+ KimiLinearAttnBackend,
192
193
  Mamba2AttnBackend,
193
194
  )
194
195
  from sglang.srt.utils import is_blackwell, is_npu
@@ -207,6 +208,8 @@ def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBac
207
208
  linear_attn_backend = GDNAttnBackend(runner)
208
209
  elif runner.mamba2_config is not None:
209
210
  linear_attn_backend = Mamba2AttnBackend(runner)
211
+ elif runner.kimi_linear_config is not None:
212
+ linear_attn_backend = KimiLinearAttnBackend(runner)
210
213
  else:
211
214
  raise ValueError(
212
215
  "Expected hybrid GDN or NemotronH models, but got unknown model."
@@ -21,6 +21,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
21
21
  @triton.heuristics(
22
22
  {
23
23
  "USE_G": lambda args: args["g"] is not None,
24
+ "USE_GK": lambda args: args["gk"] is not None,
24
25
  "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
25
26
  "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
26
27
  "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
@@ -44,6 +45,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
44
45
  w,
45
46
  v_new,
46
47
  g,
48
+ gk,
47
49
  h,
48
50
  h0,
49
51
  ht,
@@ -57,6 +59,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
57
59
  BT: tl.constexpr,
58
60
  BV: tl.constexpr,
59
61
  USE_G: tl.constexpr,
62
+ USE_GK: tl.constexpr,
60
63
  USE_INITIAL_STATE: tl.constexpr,
61
64
  STORE_FINAL_STATE: tl.constexpr,
62
65
  SAVE_NEW_VALUE: tl.constexpr,
@@ -86,12 +89,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
86
89
  b_h4 = tl.zeros([64, BV], dtype=tl.float32)
87
90
 
88
91
  # calculate offset
89
- h += (boh * H + i_h) * K * V
90
- v += (bos * H + i_h) * V
91
- k += (bos * Hg + i_h // (H // Hg)) * K
92
- w += (bos * H + i_h) * K
92
+ h += ((boh * H + i_h) * K * V).to(tl.int64)
93
+ v += ((bos * H + i_h) * V).to(tl.int64)
94
+ k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
95
+ w += ((bos * H + i_h) * K).to(tl.int64)
93
96
  if SAVE_NEW_VALUE:
94
- v_new += (bos * H + i_h) * V
97
+ v_new += ((bos * H + i_h) * V).to(tl.int64)
95
98
  stride_v = H * V
96
99
  stride_h = H * K * V
97
100
  stride_k = Hg * K
@@ -143,58 +146,48 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
143
146
  )
144
147
  tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
145
148
 
146
- p_v = tl.make_block_ptr(
147
- v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
148
- )
149
- p_v_new = (
150
- tl.make_block_ptr(
151
- v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
152
- )
153
- if SAVE_NEW_VALUE
154
- else None
155
- )
156
- b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
157
149
  p_w = tl.make_block_ptr(
158
150
  w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
159
151
  )
160
152
  b_w = tl.load(p_w, boundary_check=(0, 1))
161
- b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
153
+ b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
162
154
  if K > 64:
163
155
  p_w = tl.make_block_ptr(
164
156
  w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
165
157
  )
166
158
  b_w = tl.load(p_w, boundary_check=(0, 1))
167
- b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
159
+ b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
168
160
  if K > 128:
169
161
  p_w = tl.make_block_ptr(
170
162
  w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
171
163
  )
172
164
  b_w = tl.load(p_w, boundary_check=(0, 1))
173
- b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
165
+ b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
174
166
  if K > 192:
175
167
  p_w = tl.make_block_ptr(
176
168
  w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
177
169
  )
178
170
  b_w = tl.load(p_w, boundary_check=(0, 1))
179
- b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
180
- b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
171
+ b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
172
+ p_v = tl.make_block_ptr(
173
+ v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
174
+ )
175
+ b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
181
176
 
182
177
  if SAVE_NEW_VALUE:
183
- p_v_new = tl.make_block_ptr(
178
+ p_v = tl.make_block_ptr(
184
179
  v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
185
180
  )
186
- tl.store(
187
- p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)
188
- )
181
+ tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
189
182
 
183
+ last_idx = min((i_t + 1) * BT, T) - 1
190
184
  if USE_G:
191
- last_idx = min((i_t + 1) * BT, T) - 1
192
185
  b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
193
186
  p_g = tl.make_block_ptr(
194
187
  g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
195
188
  )
196
189
  b_g = tl.load(p_g, boundary_check=(0,))
197
- b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
190
+ b_v = b_v * safe_exp(b_g_last - b_g)[:, None]
198
191
  b_g_last = exp(b_g_last)
199
192
  b_h1 = b_h1 * b_g_last
200
193
  if K > 64:
@@ -203,30 +196,64 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
203
196
  b_h3 = b_h3 * b_g_last
204
197
  if K > 192:
205
198
  b_h4 = b_h4 * b_g_last
206
- b_v_new = b_v_new.to(k.dtype.element_ty)
199
+
200
+ if USE_GK:
201
+ o_k1 = tl.arange(0, 64)
202
+ b_gk_last1 = tl.load(
203
+ gk + (bos + last_idx) * H * K + i_h * K + o_k1,
204
+ mask=(o_k1 < K),
205
+ other=0.0,
206
+ )
207
+ b_h1 *= exp(b_gk_last1)[:, None]
208
+ if K > 64:
209
+ o_k2 = 64 + o_k1
210
+ b_gk_last2 = tl.load(
211
+ gk + (bos + last_idx) * H * K + i_h * K + o_k2,
212
+ mask=(o_k2 < K),
213
+ other=0.0,
214
+ )
215
+ b_h2 *= exp(b_gk_last2)[:, None]
216
+ if K > 128:
217
+ o_k3 = 128 + o_k1
218
+ b_gk_last3 = tl.load(
219
+ gk + (bos + last_idx) * H * K + i_h * K + o_k3,
220
+ mask=(o_k3 < K),
221
+ other=0.0,
222
+ )
223
+ b_h3 *= exp(b_gk_last3)[:, None]
224
+ if K > 192:
225
+ o_k4 = 192 + o_k1
226
+ b_gk_last4 = tl.load(
227
+ gk + (bos + last_idx) * H * K + i_h * K + o_k4,
228
+ mask=(o_k4 < K),
229
+ other=0.0,
230
+ )
231
+ b_h4 *= exp(b_gk_last4)[:, None]
232
+ b_v = b_v.to(k.dtype.element_ty)
233
+
207
234
  p_k = tl.make_block_ptr(
208
235
  k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
209
236
  )
210
237
  b_k = tl.load(p_k, boundary_check=(0, 1))
211
- b_h1 += tl.dot(b_k, b_v_new)
238
+ b_h1 += tl.dot(b_k, b_v)
212
239
  if K > 64:
213
240
  p_k = tl.make_block_ptr(
214
241
  k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
215
242
  )
216
243
  b_k = tl.load(p_k, boundary_check=(0, 1))
217
- b_h2 += tl.dot(b_k, b_v_new)
244
+ b_h2 += tl.dot(b_k, b_v)
218
245
  if K > 128:
219
246
  p_k = tl.make_block_ptr(
220
247
  k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
221
248
  )
222
249
  b_k = tl.load(p_k, boundary_check=(0, 1))
223
- b_h3 += tl.dot(b_k, b_v_new)
250
+ b_h3 += tl.dot(b_k, b_v)
224
251
  if K > 192:
225
252
  p_k = tl.make_block_ptr(
226
253
  k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
227
254
  )
228
255
  b_k = tl.load(p_k, boundary_check=(0, 1))
229
- b_h4 += tl.dot(b_k, b_v_new)
256
+ b_h4 += tl.dot(b_k, b_v)
230
257
 
231
258
  # epilogue
232
259
  if STORE_FINAL_STATE:
@@ -254,6 +281,7 @@ def chunk_gated_delta_rule_fwd_h(
254
281
  w: torch.Tensor,
255
282
  u: torch.Tensor,
256
283
  g: Optional[torch.Tensor] = None,
284
+ gk: Optional[torch.Tensor] = None,
257
285
  initial_state: Optional[torch.Tensor] = None,
258
286
  output_final_state: bool = False,
259
287
  chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
@@ -296,6 +324,7 @@ def chunk_gated_delta_rule_fwd_h(
296
324
  w=w,
297
325
  v_new=v_new,
298
326
  g=g,
327
+ gk=gk,
299
328
  h=h,
300
329
  h0=initial_state,
301
330
  ht=final_state,
@@ -44,6 +44,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
44
44
  IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar,
45
45
  USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
46
46
  IS_VARLEN: tl.constexpr,
47
+ IS_KDA: tl.constexpr,
47
48
  ):
48
49
  i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
49
50
  i_n, i_hv = i_nh // HV, i_nh % HV
@@ -67,7 +68,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
67
68
  p_beta = beta + (bos * HV + i_hv) * V + o_v
68
69
  else:
69
70
  p_beta = beta + bos * HV + i_hv
70
- p_g = g + bos * HV + i_hv
71
+ if not IS_KDA:
72
+ p_g = g + bos * HV + i_hv
73
+ else:
74
+ p_gk = g + (bos * HV + i_hv) * K + o_k
75
+
71
76
  p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
72
77
 
73
78
  mask_k = o_k < K
@@ -83,14 +88,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
83
88
  b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
84
89
  b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
85
90
  b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
86
- b_g = tl.load(p_g).to(tl.float32)
87
91
 
88
92
  if USE_QK_L2NORM_IN_KERNEL:
89
93
  b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
90
94
  b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
91
95
  b_q = b_q * scale
92
96
  # [BK, BV]
93
- b_h *= exp(b_g)
97
+ if not IS_KDA:
98
+ b_g = tl.load(p_g).to(tl.float32)
99
+ b_h *= exp(b_g)
100
+ else:
101
+ b_gk = tl.load(p_gk).to(tl.float32)
102
+ b_h *= exp(b_gk[:, None])
94
103
  # [BV]
95
104
  b_v -= tl.sum(b_h * b_k[:, None], 0)
96
105
  if IS_BETA_HEADWISE:
@@ -108,7 +117,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
108
117
  p_k += H * K
109
118
  p_o += HV * V
110
119
  p_v += HV * V
111
- p_g += HV
120
+ if not IS_KDA:
121
+ p_g += HV
122
+ else:
123
+ p_gk += HV * K
112
124
  p_beta += HV * (V if IS_BETA_HEADWISE else 1)
113
125
 
114
126
  if STORE_FINAL_STATE:
@@ -165,6 +177,7 @@ def fused_recurrent_gated_delta_rule_fwd(
165
177
  BV=BV,
166
178
  IS_BETA_HEADWISE=beta.ndim == v.ndim,
167
179
  USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
180
+ IS_KDA=False,
168
181
  num_warps=num_warps,
169
182
  num_stages=num_stages,
170
183
  )