sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
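
The headline addition in this release is the vendored flash-linear-attention (`fla`) kernel suite (items 51-64), which backs the new hybrid linear-attention models such as `qwen3_next` (items 67-68, 207-208). The two hunks below show the chunked gated-delta-rule entry points. As orientation, here is a minimal smoke-test sketch of the public entry point `chunk_gated_delta_rule`; the import path is inferred from the file list above, the shapes follow the function's own docstring, and a CUDA device with Triton is assumed (the code below rejects float32 inputs).

import torch
import torch.nn.functional as F

from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule

# Small, illustrative sizes; K must be <= 256 per the kernel's assertion.
B, T, H, K, V = 2, 256, 4, 128, 128
q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda")
k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device="cuda"), p=2, dim=-1)
v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device="cuda")
beta = torch.rand(B, T, H, dtype=torch.bfloat16, device="cuda").sigmoid()  # update strengths in (0, 1)
g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device="cuda"))  # forget gates, log space

o, final_state = chunk_gated_delta_rule(q, k, v, g, beta, output_final_state=True)
print(o.shape, final_state.shape)  # [B, T, H, V] and [B, H, K, V]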
--- /dev/null
+++ b/sglang/srt/layers/attention/fla/chunk.py
@@ -0,0 +1,242 @@
+ # Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/chunk.py
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ import warnings
+ from typing import Optional
+
+ import torch
+ from einops import rearrange
+
+ from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
+ from sglang.srt.layers.attention.fla.chunk_o import chunk_fwd_o
+ from sglang.srt.layers.attention.fla.chunk_scaled_dot_kkt import (
+     chunk_scaled_dot_kkt_fwd,
+ )
+ from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum
+ from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
+ from sglang.srt.layers.attention.fla.solve_tril import solve_tril
+ from sglang.srt.layers.attention.fla.utils import (
+     SUPPRESS_LEVEL,
+     autocast_custom_fwd,
+     input_guard,
+ )
+ from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd
+
+
+ def chunk_gated_delta_rule_fwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     g: torch.Tensor,
+     beta: torch.Tensor,
+     scale: float,
+     initial_state: torch.Tensor,
+     output_final_state: bool,
+     cu_seqlens: Optional[torch.LongTensor] = None,
+ ):
+     g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
+     # obtain WY representation. u is actually the new v.
+     A = chunk_scaled_dot_kkt_fwd(
+         k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
+     )
+     A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
+     w, u = recompute_w_u_fwd(
+         k=k,
+         v=v,
+         beta=beta,
+         A=A,
+         g_cumsum=g,
+         cu_seqlens=cu_seqlens,
+     )
+     h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
+         k=k,
+         w=w,
+         u=u,
+         g=g,
+         initial_state=initial_state,
+         output_final_state=output_final_state,
+         cu_seqlens=cu_seqlens,
+     )
+     o = chunk_fwd_o(
+         q=q,
+         k=k,
+         v=v_new,
+         h=h,
+         g=g,
+         scale=scale,
+         cu_seqlens=cu_seqlens,
+     )
+     if SUPPRESS_LEVEL < 3:
+         return g, o, A, final_state, None, None, None
+     elif SUPPRESS_LEVEL >= 3:
+         return g, o, A, final_state, w, h, v_new
+
+
+ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_fwd
+     def forward(
+         ctx,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         g: torch.Tensor,
+         beta: torch.Tensor,
+         scale: float,
+         initial_state: torch.Tensor,
+         output_final_state: bool,
+         cu_seqlens: Optional[torch.LongTensor] = None,
+         use_qk_l2norm_in_kernel: bool = False,
+     ):
+         q_orig = q
+         k_orig = k
+
+         if use_qk_l2norm_in_kernel:
+             q = l2norm_fwd(q)
+             k = l2norm_fwd(k)
+
+         g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
+             q=q,
+             k=k,
+             v=v,
+             g=g,
+             beta=beta,
+             scale=scale,
+             initial_state=initial_state,
+             output_final_state=output_final_state,
+             cu_seqlens=cu_seqlens,
+         )
+         return o.to(q.dtype), final_state
+
+
+ @torch.compiler.disable
+ def chunk_gated_delta_rule(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     g: torch.Tensor,
+     beta: torch.Tensor,
+     scale: float = None,
+     initial_state: torch.Tensor = None,
+     output_final_state: bool = False,
+     cu_seqlens: Optional[torch.LongTensor] = None,
+     head_first: bool = False,
+     use_qk_l2norm_in_kernel: bool = False,
+ ):
+     r"""
+     Args:
+         q (torch.Tensor):
+             queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+         k (torch.Tensor):
+             keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+         v (torch.Tensor):
+             values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+         g (torch.Tensor):
+             (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+         beta (torch.Tensor):
+             betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+         scale (Optional[int]):
+             Scale factor for the RetNet attention scores.
+             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+         initial_state (Optional[torch.Tensor]):
+             Initial state of shape `[N, H, K, V]` for `N` input sequences.
+             For equal-length input sequences, `N` equals the batch size `B`.
+             Default: `None`.
+         output_final_state (Optional[bool]):
+             Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+         cu_seqlens (torch.LongTensor):
+             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+             consistent with the FlashAttention API.
+         head_first (Optional[bool]):
+             Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+             Default: `False`.
+
+     Returns:
+         o (torch.Tensor):
+             Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+         final_state (torch.Tensor):
+             Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+
+     Examples::
+         >>> import torch
+         >>> import torch.nn.functional as F
+         >>> from einops import rearrange
+         >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
+         # inputs with equal lengths
+         >>> B, T, H, K, V = 4, 2048, 4, 512, 512
+         >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
+         >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
+         >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
+         >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
+         >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
+         >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
+         >>> o, ht = chunk_gated_delta_rule(
+             q, k, v, g, beta,
+             initial_state=h0,
+             output_final_state=True
+         )
+         # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+         >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
+         # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+         >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+         >>> o_var, ht_var = chunk_gated_delta_rule(
+             q, k, v, g, beta,
+             initial_state=h0,
+             output_final_state=True,
+             cu_seqlens=cu_seqlens
+         )
+     """
+     assert q.dtype == k.dtype == v.dtype
+     assert (
+         q.dtype != torch.float32
+     ), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
+     assert (
+         len(beta.shape) == 3
+     ), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
+
+     if head_first:
+         raise DeprecationWarning(
+             "head_first is deprecated and will be removed in a future version. "
+             "Please use head_first=False for now instead."
+         )
+         q, k, v, beta, g = map(
+             lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
+         )
+     # if not head_first and q.shape[1] < q.shape[2]:
+     #     warnings.warn(
+     #         f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
+     #         "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
+     #         "when head_first=False was specified. "
+     #         "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
+     #     )
+     if cu_seqlens is not None:
+         if q.shape[0] != 1:
+             raise ValueError(
+                 f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
+                 f"Please flatten variable-length inputs before processing."
+             )
+         if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+             raise ValueError(
+                 f"The number of initial states is expected to be equal to the number of input sequences, "
+                 f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+             )
+     if scale is None:
+         scale = k.shape[-1] ** -0.5
+     o, final_state = ChunkGatedDeltaRuleFunction.apply(
+         q,
+         k,
+         v,
+         g,
+         beta,
+         scale,
+         initial_state,
+         output_final_state,
+         cu_seqlens,
+         use_qk_l2norm_in_kernel,
+     )
+     if head_first:
+         o = rearrange(o, "b t h ... -> b h t ...")
+     return o, final_state
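
The forward pass above is the chunk-parallel form of a simple sequential recurrence: at each timestep the state decays by exp(g_t), a delta-rule correction beta_t * (v_t - S^T k_t) is written back along k_t, and the output is read out with q_t. A naive O(T) reference sketch of that recurrence, for intuition only; it is not part of this diff, and the chunked kernels are expected to match it only up to numerical error:

import torch

def gated_delta_rule_ref(q, k, v, g, beta, scale=None, initial_state=None):
    # q, k: [B, T, H, K]; v: [B, T, H, V]; g, beta: [B, T, H]; state S: [B, H, K, V]
    B, T, H, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5 if scale is None else scale
    S = (initial_state.float().clone() if initial_state is not None
         else q.new_zeros(B, H, K, V, dtype=torch.float32))
    q, k, v = q.float(), k.float(), v.float()
    o = torch.empty(B, T, H, V, dtype=torch.float32, device=q.device)
    for t in range(T):
        S = S * g[:, t].float().exp()[..., None, None]        # per-head forget gate
        err = torch.einsum("bhk,bhkv->bhv", k[:, t], S)       # current prediction S^T k_t
        u = beta[:, t].float()[..., None] * (v[:, t] - err)   # delta-rule correction
        S = S + torch.einsum("bhk,bhv->bhkv", k[:, t], u)     # rank-1 state update
        o[:, t] = torch.einsum("bhk,bhkv->bhv", q[:, t] * scale, S)
    return o, S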
--- /dev/null
+++ b/sglang/srt/layers/attention/fla/chunk_delta_h.py
@@ -0,0 +1,314 @@
+ # Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_delta_h.py
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Optional, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.layers.attention.fla.index import (
+     prepare_chunk_indices,
+     prepare_chunk_offsets,
+ )
+ from sglang.srt.layers.attention.fla.op import exp, safe_exp
+ from sglang.srt.layers.attention.fla.utils import is_nvidia_hopper
+
+ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
+
+
+ @triton.heuristics(
+     {
+         "USE_G": lambda args: args["g"] is not None,
+         "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
+         "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
+         "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
+         "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+     }
+ )
+ # @triton.autotune(
+ #     configs=[
+ #         triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages)
+ #         for num_warps in [2, 4]
+ #         for num_stages in [2, 3, 4]
+ #         for BV in [32, 64]
+ #     ],
+ #     key=["H", "K", "V", "BT", "USE_G"],
+ #     use_cuda_graph=use_cuda_graph,
+ # )
+ @triton.jit(do_not_specialize=["T"])
+ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
+     k,
+     v,
+     w,
+     v_new,
+     g,
+     h,
+     h0,
+     ht,
+     cu_seqlens,
+     chunk_offsets,
+     T,
+     H: tl.constexpr,
+     Hg: tl.constexpr,
+     K: tl.constexpr,
+     V: tl.constexpr,
+     BT: tl.constexpr,
+     BV: tl.constexpr,
+     USE_G: tl.constexpr,
+     USE_INITIAL_STATE: tl.constexpr,
+     STORE_FINAL_STATE: tl.constexpr,
+     SAVE_NEW_VALUE: tl.constexpr,
+     IS_VARLEN: tl.constexpr,
+ ):
+     i_v, i_nh = tl.program_id(0), tl.program_id(1)
+     i_n, i_h = i_nh // H, i_nh % H
+     if IS_VARLEN:
+         bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
+             cu_seqlens + i_n + 1
+         ).to(tl.int32)
+         T = eos - bos
+         NT = tl.cdiv(T, BT)
+         boh = tl.load(chunk_offsets + i_n).to(tl.int32)
+     else:
+         bos, eos = i_n * T, i_n * T + T
+         NT = tl.cdiv(T, BT)
+         boh = i_n * NT
+
+     # [BK, BV]
+     b_h1 = tl.zeros([64, BV], dtype=tl.float32)
+     if K > 64:
+         b_h2 = tl.zeros([64, BV], dtype=tl.float32)
+     if K > 128:
+         b_h3 = tl.zeros([64, BV], dtype=tl.float32)
+     if K > 192:
+         b_h4 = tl.zeros([64, BV], dtype=tl.float32)
+
+     # calculate offset
+     h += (boh * H + i_h) * K * V
+     v += (bos * H + i_h) * V
+     k += (bos * Hg + i_h // (H // Hg)) * K
+     w += (bos * H + i_h) * K
+     if SAVE_NEW_VALUE:
+         v_new += (bos * H + i_h) * V
+     stride_v = H * V
+     stride_h = H * K * V
+     stride_k = Hg * K
+     stride_w = H * K
+     if USE_INITIAL_STATE:
+         h0 = h0 + i_nh * K * V
+     if STORE_FINAL_STATE:
+         ht = ht + i_nh * K * V
+
+     # load initial state
+     if USE_INITIAL_STATE:
+         p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
+         b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
+         if K > 64:
+             p_h0_2 = tl.make_block_ptr(
+                 h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
+             )
+             b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
+         if K > 128:
+             p_h0_3 = tl.make_block_ptr(
+                 h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
+             )
+             b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
+         if K > 192:
+             p_h0_4 = tl.make_block_ptr(
+                 h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
+             )
+             b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
+
+     # main recurrence
+     for i_t in range(NT):
+         p_h1 = tl.make_block_ptr(
+             h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)
+         )
+         tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
+         if K > 64:
+             p_h2 = tl.make_block_ptr(
+                 h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
+             )
+             tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
+         if K > 128:
+             p_h3 = tl.make_block_ptr(
+                 h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
+             )
+             tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
+         if K > 192:
+             p_h4 = tl.make_block_ptr(
+                 h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
+             )
+             tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
+
+         p_v = tl.make_block_ptr(
+             v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
+         )
+         p_v_new = (
+             tl.make_block_ptr(
+                 v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
+             )
+             if SAVE_NEW_VALUE
+             else None
+         )
+         b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
+         p_w = tl.make_block_ptr(
+             w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
+         )
+         b_w = tl.load(p_w, boundary_check=(0, 1))
+         b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
+         if K > 64:
+             p_w = tl.make_block_ptr(
+                 w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
+             )
+             b_w = tl.load(p_w, boundary_check=(0, 1))
+             b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
+         if K > 128:
+             p_w = tl.make_block_ptr(
+                 w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
+             )
+             b_w = tl.load(p_w, boundary_check=(0, 1))
+             b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
+         if K > 192:
+             p_w = tl.make_block_ptr(
+                 w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
+             )
+             b_w = tl.load(p_w, boundary_check=(0, 1))
+             b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
+         b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
+
+         if SAVE_NEW_VALUE:
+             p_v_new = tl.make_block_ptr(
+                 v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
+             )
+             tl.store(
+                 p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)
+             )
+
+         if USE_G:
+             last_idx = min((i_t + 1) * BT, T) - 1
+             b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
+             p_g = tl.make_block_ptr(
+                 g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
+             )
+             b_g = tl.load(p_g, boundary_check=(0,))
+             b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
+             b_g_last = exp(b_g_last)
+             b_h1 = b_h1 * b_g_last
+             if K > 64:
+                 b_h2 = b_h2 * b_g_last
+             if K > 128:
+                 b_h3 = b_h3 * b_g_last
+             if K > 192:
+                 b_h4 = b_h4 * b_g_last
+         b_v_new = b_v_new.to(k.dtype.element_ty)
+         p_k = tl.make_block_ptr(
+             k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
+         )
+         b_k = tl.load(p_k, boundary_check=(0, 1))
+         b_h1 += tl.dot(b_k, b_v_new)
+         if K > 64:
+             p_k = tl.make_block_ptr(
+                 k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
+             )
+             b_k = tl.load(p_k, boundary_check=(0, 1))
+             b_h2 += tl.dot(b_k, b_v_new)
+         if K > 128:
+             p_k = tl.make_block_ptr(
+                 k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
+             )
+             b_k = tl.load(p_k, boundary_check=(0, 1))
+             b_h3 += tl.dot(b_k, b_v_new)
+         if K > 192:
+             p_k = tl.make_block_ptr(
+                 k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
+             )
+             b_k = tl.load(p_k, boundary_check=(0, 1))
+             b_h4 += tl.dot(b_k, b_v_new)
+
+     # epilogue
+     if STORE_FINAL_STATE:
+         p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
+         tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+         if K > 64:
+             p_ht = tl.make_block_ptr(
+                 ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
+             )
+             tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+         if K > 128:
+             p_ht = tl.make_block_ptr(
+                 ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
+             )
+             tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+         if K > 192:
+             p_ht = tl.make_block_ptr(
+                 ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
+             )
+             tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+
+ def chunk_gated_delta_rule_fwd_h(
+     k: torch.Tensor,
+     w: torch.Tensor,
+     u: torch.Tensor,
+     g: Optional[torch.Tensor] = None,
+     initial_state: Optional[torch.Tensor] = None,
+     output_final_state: bool = False,
+     chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
+     save_new_value: bool = True,
+     cu_seqlens: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     B, T, Hg, K, V = *k.shape, u.shape[-1]
+     H = u.shape[-2]
+     BT = chunk_size
+
+     chunk_indices = (
+         prepare_chunk_indices(cu_seqlens, chunk_size)
+         if cu_seqlens is not None
+         else None
+     )
+     # N: the actual number of sequences in the batch with either equal or variable lengths
+     if cu_seqlens is None:
+         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+     else:
+         N, NT, chunk_offsets = (
+             len(cu_seqlens) - 1,
+             len(chunk_indices),
+             prepare_chunk_offsets(cu_seqlens, BT),
+         )
+     assert K <= 256, "current kernel does not support head dimension larger than 256."
+
+     h = k.new_empty(B, NT, H, K, V)
+     final_state = (
+         k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
+     )
+
+     v_new = torch.empty_like(u) if save_new_value else None
+
+     def grid(meta):
+         return (triton.cdiv(V, meta["BV"]), N * H)
+
+     chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
+         k=k,
+         v=u,
+         w=w,
+         v_new=v_new,
+         g=g,
+         h=h,
+         h0=initial_state,
+         ht=final_state,
+         cu_seqlens=cu_seqlens,
+         chunk_offsets=chunk_offsets,
+         T=T,
+         H=H,
+         Hg=Hg,
+         K=K,
+         V=V,
+         BT=BT,
+         BV=32,
+         num_warps=4,
+         num_stages=2,
+     )
+     return h, v_new, final_state
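
Both files share the same variable-length convention: sequences are packed into a single batch row, `cu_seqlens` holds `[N+1]` cumulative token boundaries, and `chunk_offsets` tells each program where sequence `i_n`'s 64-token chunk states begin in `h` (`boh = tl.load(chunk_offsets + i_n)`). A small sketch of that bookkeeping; `prepare_chunk_offsets` itself lives in the new `fla/index.py` (item 58), and the arithmetic below is a reading of what it computes, not its verbatim implementation:

import torch
import torch.nn.functional as F

lens = torch.tensor([100, 64, 200])                    # three packed sequences
cu_seqlens = F.pad(lens.cumsum(0), (1, 0))             # tensor([  0, 100, 164, 364])
BT = 64                                                # chunk size used by the kernels
n_chunks = (lens + BT - 1) // BT                       # tensor([2, 1, 4]) chunks per sequence
chunk_offsets = F.pad(n_chunks.cumsum(0), (1, 0))[:-1] # tensor([0, 2, 3]) start of each
                                                       # sequence's chunk states in h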