rwkv-ops 0.6.1 (rwkv_ops-0.6.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py
@@ -0,0 +1,190 @@
+ """
+ JAX RWKV7 single-step wkv kernel (forward pass only).
+ The CUDA extension is compiled lazily; optimized for the T = 1 case.
+ """
+
+ from __future__ import annotations
+ import pathlib
+ import subprocess
+ import ctypes
+ import jax
+ import jax.numpy as jnp
+ from typing import Optional, Tuple, Union
+
+ # ---------- Lazy compilation (build output kept in the current directory) ----------
+ _CURRENT_DIR = pathlib.Path(__file__).parent.absolute()
+
+
+ def get_jax_generalized_delta_rule_single_step(HEAD_SIZE=64):
+     _BUILD_DIR = _CURRENT_DIR / f"build_single_step_{HEAD_SIZE}"
+     _SO_PATH = _CURRENT_DIR / f"build_single_step_{HEAD_SIZE}/wkv7_single_step.so"
+
+     def _ensure_compiled() -> pathlib.Path:
+         """Compile the CUDA extension on first use; the output is placed in the current source directory."""
+         if _SO_PATH.exists():
+             return _SO_PATH
+
+         print("[rwkv7_single_step_jax] First use – compiling CUDA kernel…")
+         src_dir = _CURRENT_DIR
+         build_dir = _BUILD_DIR
+         build_dir.mkdir(exist_ok=True)
+
+         # Locate the XLA header directory
+         xla_include_dir = jax.ffi.include_dir()
+         if not xla_include_dir:
+             raise RuntimeError("jax.ffi.include_dir() returned an empty path; JAX >= 0.4.31 is required")
+
+         # CUDA compile flags (CHUNK_LEN define removed)
+         cuda_flags = [
+             "-ftz=true",
+             "-prec-div=false",
+             "-prec-sqrt=false",
+             "--use_fast_math",
+             "-O3",
+             "-Xptxas=-O3",
+             "-res-usage",
+             "--extra-device-vectorization",
+             f"-D_C_={HEAD_SIZE}",
+         ]
+
+         # CMake configuration
+         cmake_args = [
+             "cmake",
+             "-S",
+             str(src_dir),
+             "-B",
+             str(build_dir),
+             "-DCMAKE_BUILD_TYPE=Release",
+             f"-DCMAKE_INSTALL_PREFIX={_CURRENT_DIR}",
+             f"-DXLA_INCLUDE_DIR={xla_include_dir}",
+             f"-DCMAKE_CUDA_FLAGS={' '.join(cuda_flags)}",
+         ]
+         subprocess.check_call(cmake_args)
+
+         # Build and install
+         subprocess.check_call(["cmake", "--build", str(build_dir), "-j"])
+         subprocess.check_call(["cmake", "--install", str(build_dir)])
+
+         if not _SO_PATH.exists():
+             raise RuntimeError("Compilation failed – wkv7_single_step.so not found.")
+
+         print("[rwkv7_single_step_jax] Compilation finished – output at", _SO_PATH)
+         return _SO_PATH
+
+     # Register the FFI symbol (forward only)
+     _lib = ctypes.CDLL(_ensure_compiled())
+     jax.ffi.register_ffi_target(
+         "wkv7_single_step_fwd",
+         jax.ffi.pycapsule(_lib.Wkv7SingleStepFwd),
+         platform="CUDA",
+     )
+
+     # ---------- Helpers ----------
+     def _transpose_head(x: jnp.ndarray, head_first: bool) -> jnp.ndarray:
+         """(B, 1, H, K) <-> (B, H, 1, K)"""
+         x = jnp.asarray(x, dtype=jnp.bfloat16)
+         if head_first:
+             return jnp.transpose(x, (0, 2, 1, 3))
+         return x
+
+     # ---------- Forward kernel ----------
+     def _wkv7_single_step_kernel(
+         w: jnp.ndarray,
+         q: jnp.ndarray,
+         k: jnp.ndarray,
+         v: jnp.ndarray,
+         a: jnp.ndarray,
+         b: jnp.ndarray,
+         h0: jnp.ndarray,
+     ):
+         """
+         Internal kernel interface.
+         Arguments: w, q, k, v, a, b, h0 -> y, s
+         """
+         B, H, K = q.shape
+         dtype = q.dtype
+
+         out_type = jax.ShapeDtypeStruct((B, H, K), dtype)
+         s_type = jax.ShapeDtypeStruct((B, H, K, K), jnp.float32)
+
+         y, s = jax.ffi.ffi_call(
+             "wkv7_single_step_fwd", (out_type, s_type), vmap_method="broadcast_all"
+         )(w, q, k, v, a, b, h0)
+
+         return y, s
+
+     def wk7_single_step_kernel(
+         w: jnp.ndarray,
+         q: jnp.ndarray,
+         k: jnp.ndarray,
+         v: jnp.ndarray,
+         a: jnp.ndarray,
+         b: jnp.ndarray,
+         h0: jnp.ndarray,
+     ):
+         """Forward computation."""
+         y, s = _wkv7_single_step_kernel(w, q, k, v, a, b, h0)
+         final_state = s  # after a single step the state can be returned directly
+         return (y, final_state)
+
+     # ---------- Main interface ----------
+     def generalized_delta_rule_single_step(
+         r: jnp.ndarray,
+         w: jnp.ndarray,
+         k: jnp.ndarray,
+         v: jnp.ndarray,
+         a: jnp.ndarray,
+         b: jnp.ndarray,
+         initial_state: Optional[jnp.ndarray] = None,
+         output_final_state: bool = True,
+         head_first: bool = False,
+     ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
+         """
+         Single-step generalized delta rule (forward only).
+         Arguments:
+             r, w, k, v, a, b: input tensors; shape must be (B, 1, H, K) or (B, H, 1, K)
+             initial_state: optional (B, H, K, K) initial state; zero-initialized if None
+             output_final_state: whether to also return the final state
+             head_first: whether the head dimension precedes the time dimension
+         Returns:
+             out: (B, 1, H, K), same dtype as the input
+             last_state: (B, H, K, K) when output_final_state=True
+         """
+         # Normalize to (B, 1, H, K) and check that T == 1
+         r = _transpose_head(r, head_first)
+         w = _transpose_head(w, head_first)
+         k = _transpose_head(k, head_first)
+         v = _transpose_head(v, head_first)
+         a = _transpose_head(a, head_first)
+         b = _transpose_head(b, head_first)
+
+         B, T, H, K = r.shape
+         if T != 1:
+             raise ValueError(f"Single-step kernel requires T=1, but got T={T}.")
+
+         # Prepare the initial state
+         if initial_state is None:
+             h0 = jnp.zeros((B, H, K, K), jnp.float32)
+         else:
+             h0 = jnp.asarray(initial_state, jnp.float32)
+
+         # Drop the T dimension before calling the kernel
+         r = r[:, 0, :, :]  # (B, H, K)
+         w = w[:, 0, :, :]
+         k = k[:, 0, :, :]
+         v = v[:, 0, :, :]
+         a = a[:, 0, :, :]
+         b = b[:, 0, :, :]
+
+         # Call the forward kernel
+         out, last_state = wk7_single_step_kernel(w, r, k, v, a, b, h0)
+
+         # Restore the T dimension
+         out = jnp.expand_dims(out, axis=1)  # (B, 1, H, K)
+         out = jnp.asarray(out, r.dtype)
+
+         if output_final_state:
+             return out, last_state
+         return out
+
+     return generalized_delta_rule_single_step
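
For reference, a minimal usage sketch of the single-step kernel above. It assumes a CUDA GPU with nvcc and CMake on PATH and JAX >= 0.4.31 (the first call to the factory triggers the CMake build); the batch size, head count, and HEAD_SIZE below are illustrative, and the module is imported by its full path inside the wheel rather than through whatever rwkv_ops/__init__.py re-exports.

import jax.numpy as jnp
from rwkv_ops.rwkv7_kernel.jax_cuda_kernel_single.wkv7_single_step_jax import (
    get_jax_generalized_delta_rule_single_step,
)

B, H, K = 2, 8, 64                      # batch, heads, head size (K must equal HEAD_SIZE)
step = get_jax_generalized_delta_rule_single_step(HEAD_SIZE=K)  # compiles the .so on first use

# Inputs are (B, 1, H, K); the wrapper casts them to bfloat16 internally.
r, w, k, v, a, b = (jnp.zeros((B, 1, H, K), jnp.bfloat16) for _ in range(6))
h0 = jnp.zeros((B, H, K, K), jnp.float32)  # optional initial state

out, last_state = step(r, w, k, v, a, b, initial_state=h0, output_final_state=True)
print(out.shape, last_state.shape)      # (2, 1, 8, 64) and (2, 8, 64, 64)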
rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py
@@ -0,0 +1,9 @@
+ from ..jax_kernel.chunk_A_fwd import *
+ from ..jax_kernel.chunk_A_bwd import *
+ from ..jax_kernel.chunk_h_fwd import *
+ from ..jax_kernel.chunk_h_bwd import *
+ from ..jax_kernel.chunk_o_fwd import *
+ from ..jax_kernel.chunk_o_bwd import *
+ from ..jax_kernel.cumsum import *
+ from ..jax_kernel.wy_fast_fwd import *
+ from ..jax_kernel.wy_fast_bwd import *
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py
@@ -0,0 +1,95 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+ import jax_triton as jt
+ import jax
+ import triton
+ from ..triton_kernel.chunk_A_bwd import *
+ from ..triton_kernel.utils import is_gather_supported
+ from ..get_torch_devices_info import check_shared_mem
+
+
+ def chunk_dplr_bwd_dqk_intra(
+     q: jax.Array,
+     k: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gi: jax.Array,
+     ge: jax.Array,
+     dAqk: jax.Array,
+     dAqb: jax.Array,
+     dAak: jax.Array,
+     dAab: jax.Array,
+     dqg: jax.Array,
+     dkg: jax.Array,
+     dag: jax.Array,
+     dbg: jax.Array,
+     dgk_last: jax.Array,
+     scale: float = 1.0,
+     chunk_size: int = 16,
+ ):
+     B, T, H, K = q.shape
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+     BK = (
+         min(64, triton.next_power_of_2(K))
+         if check_shared_mem()
+         else min(32, triton.next_power_of_2(K))
+     )
+
+     NT = triton.cdiv(T, BT)
+     NK = triton.cdiv(K, BK)
+     grid = (NK, NT, B * H)
+
+     out_shapes = [
+         jax.ShapeDtypeStruct(q.shape, q.dtype),
+         jax.ShapeDtypeStruct(k.shape, k.dtype),
+         jax.ShapeDtypeStruct(a.shape, a.dtype),
+         jax.ShapeDtypeStruct(b.shape, b.dtype),
+         jax.ShapeDtypeStruct(gi.shape, "float32"),
+         jax.ShapeDtypeStruct(gi.shape, "float32"),
+     ]
+
+     dq, dk, da, db, dgk, dgk_offset = jt.triton_call(
+         q,
+         k,
+         a,
+         b,
+         gi,
+         ge,
+         dAqk,
+         dAqb,
+         dAak,
+         dAab,
+         dqg,
+         dkg,
+         dag,
+         dbg,
+         T,
+         scale=scale,
+         H=H,
+         K=K,
+         BT=BT,
+         BC=BT,
+         BK=BK,
+         GATHER_SUPPORTED=is_gather_supported,
+         kernel=chunk_dplr_bwd_kernel_intra,
+         out_shape=out_shapes,
+         grid=grid,
+     )
+
+     def grid(meta):
+         return (NT, triton.cdiv(K, meta["BK"]), B * H)
+
+     dgk_output = jt.triton_call(
+         dgk,
+         dgk_offset,
+         dgk_last,
+         T,
+         H=H,
+         K=K,
+         BT=BT,
+         kernel=chunk_dplr_bwd_dgk_kernel,
+         out_shape=[jax.ShapeDtypeStruct(dgk.shape, "float32")],
+         grid=grid,
+     )[0]
+     return dq, dk, da, db, dgk_output
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py
@@ -0,0 +1,60 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+ from ..triton_kernel.utils import is_gather_supported
+
+ from ..triton_kernel.chunk_A_fwd import *
+
+
+ def chunk_dplr_fwd_intra(
+     q: jax.Array,
+     k: jax.Array,
+     a: jax.Array,
+     b: jax.Array,
+     gi: jax.Array,
+     ge: jax.Array,
+     scale: float,
+     chunk_size: int,
+ ):
+     B, T, H, K = k.shape
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+
+     NT = triton.cdiv(T, BT)
+     shape = [B, T, H, BT]
+     out_shapes = [
+         jax.ShapeDtypeStruct(q.shape, q.dtype),
+         jax.ShapeDtypeStruct(k.shape, q.dtype),
+         jax.ShapeDtypeStruct(a.shape, q.dtype),
+         jax.ShapeDtypeStruct(b.shape, q.dtype),
+         jax.ShapeDtypeStruct(shape, q.dtype),
+         jax.ShapeDtypeStruct(shape, q.dtype),
+         jax.ShapeDtypeStruct(shape, "float32"),
+         jax.ShapeDtypeStruct(shape, "float32"),
+     ]
+     grid = (NT, B, H)
+     BK = triton.next_power_of_2(K)
+     qg, kg, ag, bg, Aqk, Aqb, Aab, Aak = jt.triton_call(
+         q,
+         k,
+         a,
+         b,
+         gi,
+         ge,
+         T,
+         scale=scale,
+         H=H,
+         K=K,
+         BT=BT,
+         BC=BT,
+         BK=BK,
+         GATHER_SUPPORTED=is_gather_supported,
+         kernel=chunk_dplr_fwd_A_kernel_intra_sub_intra,
+         out_shape=out_shapes,
+         grid=grid,
+     )
+     return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
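
As a concrete illustration of the chunking arithmetic used in chunk_dplr_fwd_intra (and repeated in the other chunk_* wrappers), here is a small worked example; the batch, head count, sequence length, and head size below are made up for the example:

import triton

B, H, T, K, chunk_size = 1, 4, 100, 64, 16
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))  # min(16, max(16, 128)) = 16
NT = triton.cdiv(T, BT)                                   # ceil(100 / 16) = 7 chunks
BK = triton.next_power_of_2(K)                            # 64
grid = (NT, B, H)                                         # one Triton program per (chunk, batch, head)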
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py
@@ -0,0 +1,78 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+ from typing import Optional, Tuple
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+ from ..get_jax_devices_info import check_shared_mem
+ from ..triton_kernel.chunk_h_bwd import *
+
+
+ def chunk_dplr_bwd_dhu(
+     qg: jax.Array,
+     bg: jax.Array,
+     w: jax.Array,
+     gk: jax.Array,
+     h0: jax.Array,
+     dht: Optional[jax.Array],
+     do: jax.Array,
+     dv: jax.Array,
+     chunk_size: int = 64,
+ ) -> Tuple[jax.Array, jax.Array, jax.Array]:
+     B, T, H, K, V = *qg.shape, do.shape[-1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+     BK = triton.next_power_of_2(K)
+     assert BK <= 256, (
+         "current kernel does not support head dimension being larger than 256."
+     )
+     # H100
+     if check_shared_mem("hopper"):
+         BV = 64
+         BC = 64 if K <= 128 else 32
+     elif check_shared_mem("ampere"):  # A100
+         BV = 32
+         BC = 32
+     else:  # Etc: 4090
+         BV = 16
+         BC = 16
+
+     N, NT = B, triton.cdiv(T, BT)
+     BC = min(BT, BC)
+     NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+     assert NK == 1, (
+         "NK > 1 is not supported because it involves time-consuming synchronization"
+     )
+     dh_shape = (B, NT, H, K, V)
+     out_shapes = [
+         jax.ShapeDtypeStruct(dh_shape, dv.dtype),
+         jax.ShapeDtypeStruct((B, H, K, V), "float32"),
+         jax.ShapeDtypeStruct(dv.shape, dv.dtype),
+     ]
+
+     grid = (NK, NV, N * H)
+     dh, dh0, dv2 = jt.triton_call(
+         qg,
+         bg,
+         w,
+         gk,
+         dht,
+         dv,
+         do,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BC=BC,
+         BK=BK,
+         BV=BV,
+         kernel=chunk_dplr_bwd_kernel_dhu.fn,
+         out_shape=out_shapes,
+         grid=grid,
+         USE_FINAL_STATE_GRADIENT=dht is not None,
+         USE_INITIAL_STATE=h0 is not None,
+     )
+     return dh, dh0, dv2
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py
@@ -0,0 +1,80 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+ from typing import Optional, Tuple
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+ from ..get_jax_devices_info import check_shared_mem
+ from ..triton_kernel.chunk_h_fwd import *
+
+
+ def chunk_dplr_fwd_h(
+     kg: jax.Array,
+     v: jax.Array,
+     w: jax.Array,
+     u: jax.Array,
+     bg: jax.Array,
+     gk: jax.Array,
+     initial_state: Optional[jax.Array] = None,
+     output_final_state: bool = False,
+     chunk_size: int = 64,
+ ) -> Tuple[jax.Array, jax.Array]:
+     B, T, H, K, V = *kg.shape, u.shape[-1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+
+     N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
+     BK = triton.next_power_of_2(K)
+     assert BK <= 256, "current kernel does not support head dimension larger than 256."
+     # H100 can have larger block size
+
+     if check_shared_mem("hopper"):
+         BV = 64
+         BC = 64 if K <= 128 else 32
+     elif check_shared_mem("ampere"):  # A100
+         BV = 32
+         BC = 32
+     else:
+         BV = 16
+         BC = 16
+
+     BC = min(BT, BC)
+     NK = triton.cdiv(K, BK)
+     NV = triton.cdiv(V, BV)
+     assert NK == 1, (
+         "NK > 1 is not supported because it involves time-consuming synchronization"
+     )
+
+     out_shapes = [
+         jax.ShapeDtypeStruct((B, NT, H, K, V), kg.dtype),
+         jax.ShapeDtypeStruct([N, H, K, V], "float32"),
+         jax.ShapeDtypeStruct(u.shape, u.dtype),
+     ]
+     grid = (NK, NV, N * H)
+     if initial_state is None:
+         initial_state = jax.numpy.zeros([N, H, K, V], "float32")
+     h, final_state, v_new = jt.triton_call(
+         kg,
+         v,
+         w,
+         bg,
+         u,
+         gk,
+         initial_state,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BC=BC,
+         BK=BK,
+         BV=BV,
+         kernel=chunk_dplr_fwd_kernel_h.fn,
+         out_shape=out_shapes,
+         grid=grid,
+         STORE_FINAL_STATE=True,
+         USE_INITIAL_STATE=True,
+     )
+     return h, v_new, final_state
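
chunk_dplr_fwd_h and chunk_dplr_bwd_dhu pick their value/chunk tile sizes (BV, BC) from the same shared-memory tier check; a condensed sketch of that heuristic follows. check_shared_mem is assumed to behave as in get_jax_devices_info.py (not shown in this diff), i.e. to report whether the device offers at least the named architecture's shared memory.

def pick_state_tiles(K, BT, check_shared_mem):
    # Mirrors the branch structure in chunk_dplr_fwd_h / chunk_dplr_bwd_dhu above.
    if check_shared_mem("hopper"):        # H100-class: largest tiles
        BV, BC = 64, (64 if K <= 128 else 32)
    elif check_shared_mem("ampere"):      # A100-class
        BV, BC = 32, 32
    else:                                 # smaller shared memory (e.g. consumer GPUs)
        BV, BC = 16, 16
    return BV, min(BT, BC)                # BC is additionally capped by the chunk length BT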
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py
@@ -0,0 +1,150 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+ from typing import Tuple
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+ from ..get_jax_devices_info import check_shared_mem
+ from ..triton_kernel.chunk_o_bwd import *
+
+
+ def chunk_dplr_bwd_dv(
+     A_qk: jax.Array,
+     kg: jax.Array,
+     do: jax.Array,
+     dh: jax.Array,
+     chunk_size: int = 64,
+ ) -> jax.Array:
+     B, T, H, K, V = *kg.shape, do.shape[-1]
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+
+     NT = triton.cdiv(T, BT)
+
+     def grid(meta):
+         return (triton.cdiv(V, meta["BV"]), NT, B * H)
+
+     dv = jt.triton_call(
+         A_qk,
+         kg,
+         do,
+         dh,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         kernel=chunk_dplr_bwd_kernel_dv,
+         out_shape=jax.ShapeDtypeStruct(do.shape, do.dtype),
+         grid=grid,
+     )
+     return dv
+
+
+ def chunk_dplr_bwd_o(
+     k: jax.Array,
+     b: jax.Array,
+     v: jax.Array,
+     v_new: jax.Array,
+     gk: jax.Array,
+     do: jax.Array,
+     h: jax.Array,
+     dh: jax.Array,
+     dv: jax.Array,
+     w: jax.Array,
+     chunk_size: int = 64,
+     scale: float = 1.0,
+ ) -> Tuple[jax.Array, jax.Array, jax.Array]:
+     B, T, H, K, V = *w.shape, v.shape[-1]
+
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+     NT = triton.cdiv(T, BT)
+
+     BK = (
+         min(triton.next_power_of_2(K), 64)
+         if check_shared_mem()
+         else min(triton.next_power_of_2(K), 32)
+     )
+     BV = (
+         min(triton.next_power_of_2(V), 64)
+         if check_shared_mem()
+         else min(triton.next_power_of_2(K), 32)
+     )
+     NK = triton.cdiv(K, BK)
+     grid = (NK, NT, B * H)
+
+     out_shapes = [
+         jax.ShapeDtypeStruct(k.shape, k.dtype),
+         jax.ShapeDtypeStruct(k.shape, k.dtype),
+         jax.ShapeDtypeStruct(w.shape, w.dtype),
+         jax.ShapeDtypeStruct(b.shape, b.dtype),
+         jax.ShapeDtypeStruct([B, NT, H, K], "float32"),
+     ]
+     dq, dk, dw, db, dgk_last = jt.triton_call(
+         v,
+         v_new,
+         h,
+         do,
+         dh,
+         w,
+         dv,
+         gk,
+         k,
+         b,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         BK=BK,
+         BV=BV,
+         kernel=chunk_dplr_bwd_o_kernel,
+         out_shape=out_shapes,
+         grid=grid,
+     )
+     return (dq, dk, dw, db, dgk_last)
+
+
+ def chunk_dplr_bwd_dAu(
+     v: jax.Array,
+     v_new: jax.Array,
+     do: jax.Array,
+     A_qb: jax.Array,
+     scale: float,
+     chunk_size: int = 64,
+ ) -> jax.Array:
+     B, T, H, V = v.shape
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+     NT = triton.cdiv(T, BT)
+
+     if check_shared_mem("ampere"):  # A100
+         BV = min(triton.next_power_of_2(V), 128)
+     elif check_shared_mem("ada"):  # 4090
+         BV = min(triton.next_power_of_2(V), 64)
+     else:
+         BV = min(triton.next_power_of_2(V), 32)
+
+     grid = (NT, B * H)
+     out_shapes = [
+         jax.ShapeDtypeStruct([B, T, H, BT], "float32"),
+         jax.ShapeDtypeStruct([B, T, H, BT], "float32"),
+         jax.ShapeDtypeStruct(v_new.shape, v_new.dtype),
+     ]
+     dA_qk, dA_qb, dv_new = jt.triton_call(
+         v,
+         do,
+         v_new,
+         A_qb,
+         T,
+         scale=scale,
+         H=H,
+         V=V,
+         BT=BT,
+         BV=BV,
+         grid=grid,
+         out_shape=out_shapes,
+         kernel=chunk_dplr_bwd_kernel_dAu,
+     )
+     return dv_new, dA_qk, dA_qb
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py
@@ -0,0 +1,45 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) 2023-2025,Qingwen Lin
+
+
+ import jax_triton as jt
+ import jax
+ import triton
+
+ from ..triton_kernel.chunk_o_fwd import *
+
+
+ def chunk_dplr_fwd_o(
+     qg: jax.Array,
+     v: jax.Array,
+     v_new: jax.Array,
+     A_qk: jax.Array,
+     A_qb: jax.Array,
+     h: jax.Array,
+     chunk_size: int = 64,
+ ) -> jax.Array:
+     B, T, H, K, V = *qg.shape, v.shape[-1]
+     BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
+
+     NT = triton.cdiv(T, BT)
+
+     def grid(meta):
+         return (triton.cdiv(V, meta["BV"]), NT, B * H)
+
+     o = jt.triton_call(
+         qg,
+         v,
+         v_new,
+         A_qk,
+         A_qb,
+         h,
+         T,
+         H=H,
+         K=K,
+         V=V,
+         BT=BT,
+         kernel=chunk_dplr_fwd_kernel_o,
+         out_shape=jax.ShapeDtypeStruct(v.shape, v.dtype),
+         grid=grid,
+     )
+     return o