rwkv-ops 0.2.2-py3-none-any.whl → 0.3.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rwkv-ops might be problematic.
Files changed (31)
  1. rwkv_ops/__init__.py +5 -6
  2. rwkv_ops/rwkv6_kernel/__init__.py +0 -6
  3. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  4. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  5. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  6. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  7. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  8. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  9. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  10. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  11. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  12. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  13. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +21 -23
  14. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +14 -10
  15. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  16. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  17. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +4 -4
  18. rwkv_ops/rwkv7_kernel/__init__.py +80 -29
  19. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  20. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +279 -0
  21. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +237 -0
  22. rwkv_ops/rwkv7_kernel/jax_op.py +6 -5
  23. rwkv_ops/rwkv7_kernel/native_keras_op.py +5 -6
  24. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +123 -0
  25. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +165 -0
  26. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +35 -0
  27. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/METADATA +28 -27
  28. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/RECORD +30 -13
  29. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/WHEEL +1 -2
  30. rwkv_ops-0.2.2.dist-info/top_level.txt +0 -1
  31. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info/licenses}/LICENSE.txt +0 -0
@@ -0,0 +1,237 @@
+ """
+ JAX RWKV7 wkv kernel + generalized_delta_rule.
+ Lazily compiles the CUDA extension; the interface matches the Torch version 1:1.
+ """
+
+ from __future__ import annotations
+ import pathlib
+ import subprocess
+ import ctypes
+ import jax
+ import jax.numpy as jnp
+ from typing import Optional, Tuple, Union
+ from jax.ad_checkpoint import checkpoint_policies as cp
+
+ CHUNK_LEN = 16  # fixed constant
+
+ # ---------- Lazy compilation (output goes into the current directory) ----------
+ _CURRENT_DIR = pathlib.Path(
+     __file__
+ ).parent.absolute()  # rwkv_ops/rwkv7_kernel/jax_cuda_kernel
+
+
+ def get_jax_generalized_delta_rule(HEAD_SIZE=64):
+     _BUILD_DIR = _CURRENT_DIR / f"build_{HEAD_SIZE}"
+     _SO_PATH = _CURRENT_DIR / f"build_{HEAD_SIZE}/wkv7.so"
+
+     def _ensure_compiled() -> pathlib.Path:
+         """Compile the CUDA extension on first use; the output lands in the source directory."""
+         if _SO_PATH.exists():
+             return _SO_PATH
+
+         print("[rwkv7_jax] First use – compiling CUDA kernel…")
+         src_dir = _CURRENT_DIR
+         build_dir = _BUILD_DIR
+         build_dir.mkdir(exist_ok=True)
+
+         # ---------- Key step: locate JAX's XLA header directory ----------
+         xla_include_dir = jax.ffi.include_dir()  # the key API for this approach
+         if not xla_include_dir:
+             raise RuntimeError(
+                 "jax.ffi.include_dir() returned nothing; check that JAX >= 0.4.31"
+             )
+
+         # ---------- Key step: hard-code the numerics flags ----------
+         cuda_flags = [
+             "-ftz=true",  # flush sub-normals to zero
+             "-prec-div=false",  # faster division, avoids the precise path
+             "-prec-sqrt=false",  # faster square root
+             "--use_fast_math",  # uniform fast math
+             "-O3",
+             "-Xptxas=-O3",
+             "-res-usage",
+             "--extra-device-vectorization",
+             f"-D_C_={HEAD_SIZE}",
+             f"-D_CHUNK_LEN_={CHUNK_LEN}",
+         ]
+
+         # 1. Configure
+         cmake_args = [
+             "cmake",
+             "-S",
+             str(src_dir),
+             "-B",
+             str(build_dir),
+             "-DCMAKE_BUILD_TYPE=Release",
+             f"-DCMAKE_INSTALL_PREFIX={_CURRENT_DIR}",
+             f"-DXLA_INCLUDE_DIR={xla_include_dir}",  # forwarded to CMake
+             f"-DCMAKE_CUDA_FLAGS={' '.join(cuda_flags)}",
+         ]
+         subprocess.check_call(cmake_args)
+
+         # 2. Build
+         subprocess.check_call(["cmake", "--build", str(build_dir), "-j"])
+
+         # 3. Install (copies the .so into the current directory)
+         subprocess.check_call(["cmake", "--install", str(build_dir)])
+
+         if not _SO_PATH.exists():
+             raise RuntimeError("Compilation failed – wkv7.so not found.")
+
+         print("[rwkv7_jax] Compilation finished – output at", _SO_PATH)
+         return _SO_PATH
+
+     # Register the FFI symbols
+     _lib = ctypes.CDLL(_ensure_compiled())
+     jax.ffi.register_ffi_target(
+         "wkv7_fwd", jax.ffi.pycapsule(_lib.Wkv7Fwd), platform="CUDA"
+     )
+     jax.ffi.register_ffi_target(
+         "wkv7_bwd", jax.ffi.pycapsule(_lib.Wkv7Bwd), platform="CUDA"
+     )
+
+     # ---------- Helpers ----------
+     def _transpose_head(x: jnp.ndarray, head_first: bool) -> jnp.ndarray:
+         """(B, T, H, K) <-> (B, H, T, K)"""
+         x = jnp.asarray(x, dtype=jnp.bfloat16)
+         if head_first:
+             return jnp.transpose(x, (0, 2, 1, 3))
+         return x
+
+     # ---------- Forward + backward kernels ----------
+
+     def _wkv7_kernel(
+         w: jnp.ndarray,
+         q: jnp.ndarray,
+         k: jnp.ndarray,
+         v: jnp.ndarray,
+         a: jnp.ndarray,
+         b: jnp.ndarray,
+         h0: jnp.ndarray,
+     ):
+         """
+         Internal kernel interface.
+         Argument order matches the declaration in wkv7_ffi.cu exactly:
+         w,q,k,v,a,b,h0 -> y,s,sa
+         """
+         B, T, H, K = q.shape
+         dtype = q.dtype
+         chunk_num = int(T // CHUNK_LEN)
+         out_type = jax.ShapeDtypeStruct((B, T, H, K), dtype)
+         s_type = jax.ShapeDtypeStruct((B, H, chunk_num, K, K), jnp.float32)
+         sa_type = jax.ShapeDtypeStruct((B, T, H, K), jnp.float32)
+
+         y, s, sa = jax.ffi.ffi_call(
+             "wkv7_fwd", (out_type, s_type, sa_type), vmap_method="broadcast_all"
+         )(w, q, k, v, a, b, h0)
+
+         return y, s, sa
+
+     @jax.custom_vjp
+     def wk7_kernel(
+         w: jnp.ndarray,
+         q: jnp.ndarray,
+         k: jnp.ndarray,
+         v: jnp.ndarray,
+         a: jnp.ndarray,
+         b: jnp.ndarray,
+         h0: jnp.ndarray,
+     ):
+         y, s, sa = _wkv7_kernel(w, q, k, v, a, b, h0)
+         final_state = s[:, :, -1]
+         return (y, jnp.transpose(final_state, [0, 1, 3, 2]))
+
+     # Forward rule
+     def _fwd(
+         w: jnp.ndarray,
+         q: jnp.ndarray,
+         k: jnp.ndarray,
+         v: jnp.ndarray,
+         a: jnp.ndarray,
+         b: jnp.ndarray,
+         h0: jnp.ndarray,
+     ):
+         y, s, sa = _wkv7_kernel(w, q, k, v, a, b, h0)
+         final_state = s[:, :, -1]
+         return (y, jnp.transpose(final_state, [0, 1, 3, 2])), (w, q, k, v, a, b, s, sa)
+
+     def _wkv7_bwd_kernel(w, q, k, v, a, b, dy, s, sa, dht):
+         dh0_type = jax.ShapeDtypeStruct(dht.shape, dht.dtype)
+         dw_type = jax.ShapeDtypeStruct(w.shape, w.dtype)
+         dq_type = jax.ShapeDtypeStruct(q.shape, q.dtype)
+         dk_type = jax.ShapeDtypeStruct(k.shape, k.dtype)
+         dv_type = jax.ShapeDtypeStruct(v.shape, v.dtype)
+         da_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
+         db_type = jax.ShapeDtypeStruct(b.shape, b.dtype)
+
+         dh0, dw, dq, dk, dv, da, db = jax.ffi.ffi_call(
+             "wkv7_bwd",
+             (dh0_type, dw_type, dq_type, dk_type, dv_type, da_type, db_type),
+             vmap_method="broadcast_all",
+         )(w, q, k, v, a, b, dy, s, sa, dht)
+
+         return dw, dq, dk, dv, da, db, dh0
+
+     # Backward rule
+     def _bwd(res, grads):
+         w, q, k, v, a, b, s, sa = res
+         dy, dht = grads
+         dy = jnp.asarray(dy, jnp.bfloat16)
+         # call the backward kernel
+         return _wkv7_bwd_kernel(w, q, k, v, a, b, dy, s, sa, dht)
+
+     wk7_kernel.defvjp(_fwd, _bwd)
+
+     def generalized_delta_rule(
+         r: jnp.ndarray,
+         w: jnp.ndarray,
+         k: jnp.ndarray,
+         v: jnp.ndarray,
+         a: jnp.ndarray,
+         b: jnp.ndarray,
+         initial_state: Optional[jnp.ndarray] = None,
+         output_final_state: bool = True,
+         head_first: bool = False,
+     ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
+         """
+         Generalized delta rule; interface identical to the Torch implementation.
+         Args:
+             r,w,k,v,a,b: input tensors of shape (B, T, H, K), or (B, H, T, K) when head_first=True
+             initial_state: optional (B, H, K, K) initial state; zero-initialized when None
+             output_final_state: whether to also return the final state
+             head_first: whether the head dimension precedes the time dimension
+         CHUNK_LEN is fixed at 16 and must divide T.
+         Returns:
+             out: (B, T, H, K), same dtype as the input
+             last_state: (B, H, K, K) when output_final_state=True
+         """
+         # normalize everything to (B, T, H, K)
+         dtype = r.dtype
+         r = _transpose_head(r, head_first)
+         w = _transpose_head(w, head_first)
+         k = _transpose_head(k, head_first)
+         v = _transpose_head(v, head_first)
+         a = _transpose_head(a, head_first)
+         b = _transpose_head(b, head_first)
+
+         B, T, H, K = r.shape
+         if T % CHUNK_LEN:
+             raise ValueError(
+                 f"Sequence length T={T} must be divisible by chunk_len={CHUNK_LEN}"
+             )
+
+         # handle the initial state
+         if initial_state is None:
+             h0 = jnp.zeros((B, H, K, K), jnp.float32)
+         else:
+             h0 = jnp.asarray(initial_state, jnp.float32)
+
+         # invoke the kernel
+         out, last_state = jax.checkpoint(
+             wk7_kernel, policy=cp.save_anything_except_these_names(())
+         )(w, r, k, v, a, b, h0)
+         out = jnp.asarray(out, dtype)  # keep the output dtype consistent with the input
+
+         if output_final_state:
+             return out, last_state
+         return out
+
+     return generalized_delta_rule, _wkv7_kernel, _wkv7_bwd_kernel
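
For orientation, a minimal usage sketch of the factory above. The import path and shapes are assumptions taken from this diff; it presumes a CUDA-capable JAX install and that the lazy CMake build succeeds on first call.

# Hedged usage sketch; import path and shapes are assumptions from this diff.
import jax
import jax.numpy as jnp
from rwkv_ops.rwkv7_kernel.jax_cuda_kernel.wkv7_jax import get_jax_generalized_delta_rule

generalized_delta_rule, _, _ = get_jax_generalized_delta_rule(HEAD_SIZE=64)

B, T, H, K = 1, 32, 2, 64  # T must be a multiple of CHUNK_LEN (16)
keys = jax.random.split(jax.random.PRNGKey(0), 6)
r, w, k, v, a, b = (jax.random.normal(kk, (B, T, H, K), jnp.bfloat16) for kk in keys)

out, last_state = generalized_delta_rule(r, w, k, v, a, b)  # zero initial state
print(out.shape, last_state.shape)  # (1, 32, 2, 64) (1, 2, 64, 64)
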
@@ -14,9 +14,11 @@ from .jax_kernel.chunk_o_fwd import chunk_dplr_fwd_o
  from .jax_kernel.wy_fast_bwd import chunk_dplr_bwd_wy
  from .jax_kernel.wy_fast_fwd import prepare_wy_repr_fwd
  from .jax_kernel.cumsum import chunk_rwkv6_fwd_cumsum
- from jax.ad_checkpoint import checkpoint_policies
+ from jax.ad_checkpoint import checkpoint_policies as cp
+
  CHUNKSIZE = 16

+
  def chunk_dplr_fwd(
      q: jax.Array,
      k: jax.Array,
@@ -307,7 +309,6 @@ def transpose_head(x, head_first):
      return x


- # @partial(jax.jit, static_argnames=['initial_state',"output_final_state","head_first","use_chunk"])
  def generalized_delta_rule(
      r: jax.Array,
      w: jax.Array,
@@ -365,7 +366,9 @@ def generalized_delta_rule(
      else:
          assert log_w is not None, "Either w or log_w must be provided!"
          log_w = transpose_head(log_w, head_first)
-     o, final_state = chunk_dplr(
+     o, final_state = jax.checkpoint(
+         chunk_dplr, policy=cp.save_anything_except_these_names(())
+     )(
          r=r,
          k=k,
          v=v,
@@ -377,5 +380,3 @@ def generalized_delta_rule(
      if output_final_state:
          return jnp.asarray(o, DTYPE), final_state
      return jnp.asarray(o, DTYPE)
-
-
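
The substantive change in this hunk is wrapping the chunk_dplr call in jax.checkpoint with the save_anything_except_these_names policy; with an empty exclusion list the policy permits every intermediate to be saved, so the wrapper adds named-residual control without forcing rematerialization. A standalone sketch of the same pattern:

# Standalone sketch of the jax.checkpoint pattern introduced above.
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_policies as cp

def f(x):
    h = jnp.sin(x)
    return jnp.sum(h * jnp.cos(h))

# Empty exclusion list: all intermediates may be saved on the forward pass.
f_ckpt = jax.checkpoint(f, policy=cp.save_anything_except_these_names(()))
grad = jax.grad(f_ckpt)(jnp.ones(8))
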
@@ -1,4 +1,3 @@
- import keras
  from keras import ops


@@ -62,8 +61,9 @@ def generalized_delta_rule(
          if ops.shape(state)[0] == 1:
              state = ops.broadcast_to(state, (B, H, N, N))
      else:
-         state = ops.zeros((B, H, N, N), dtype="float32")
-     out = ops.zeros((B, T, H, N), dtype=r.dtype)
+         state = ops.zeros((B, H, N, N))
+     state = ops.cast(state, "float32")
+     out = ops.zeros((B, T, H, N), DTYPE)

      def step(t, inputs):
          """
@@ -83,9 +83,8 @@ def generalized_delta_rule(
          aa = ops.reshape(a[:, t, :], (B, H, N, 1))
          bb = ops.reshape(b[:, t, :], (B, H, 1, N))
          state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
-         out = ops.slice_update(
-             out, [0, t, 0, 0], ops.reshape((state @ rr), (B, 1, H, N))
-         )
+         o = ops.cast((state @ rr), out.dtype)
+         out = ops.slice_update(out, [0, t, 0, 0], ops.reshape(o, (B, 1, H, N)))
          return [state, out]

      state, out = ops.fori_loop(0, T, step, [state, out])
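
For reference, one step of the recurrence that the loop body above implements, written for a single head in plain NumPy; this is a sketch with assumed shapes, not code from the package:

# Single-head sketch of the per-step update above (NumPy; shapes inferred from the diff).
import numpy as np

N = 4
state = np.zeros((N, N), dtype=np.float32)
w, r, k, v, a, b = (np.random.randn(N).astype(np.float32) for _ in range(6))

# state * w + state @ aa @ bb + vv @ kk, then state @ rr, per timestep
state = state * w[None, :] + state @ np.outer(a, b) + np.outer(v, k)
y = state @ r  # output for this timestep
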
@@ -0,0 +1,123 @@
+ """
+ TensorFlow version of generalized_delta_rule.
+ The forward pass calls the JAX CUDA kernel through tf.py_function; the backward pass goes through JAX as well.
+ Compilable with @tf.function, trainable with tf.GradientTape.
+ """
+
+ import tensorflow as tf
+ from typing import Optional, Tuple
+ import jax.numpy as jnp
+ from .jax_cuda_kernel.wkv7_jax import get_jax_generalized_delta_rule
+
+
+ def transpose_head(x, head_first: bool):
+     """(B, T, H, K) <-> (B, H, T, K)"""
+     x = tf.cast(x, dtype=tf.float32)
+     if head_first:
+         return tf.transpose(x, (0, 2, 1, 3))
+     return x
+
+
+ def get_tf_generalized_delta_rule(HEAD_SIZE=64):
+     _, _wkv7_kernel, _wkv7_bwd_kernel = get_jax_generalized_delta_rule(HEAD_SIZE)
+
+     # ---------- Low-level kernel wrappers ----------
+     @tf.py_function(Tout=[tf.bfloat16, tf.float32, tf.float32])
+     def _tf_wkv7_fwd(w, q, k, v, a, b, h0):
+         """tf.py_function wrapper around the JAX forward pass"""
+         y, s, sa = _wkv7_kernel(
+             jnp.asarray(w, jnp.bfloat16),
+             jnp.asarray(q, jnp.bfloat16),
+             jnp.asarray(k, jnp.bfloat16),
+             jnp.asarray(v, jnp.bfloat16),
+             jnp.asarray(a, jnp.bfloat16),
+             jnp.asarray(b, jnp.bfloat16),
+             jnp.asarray(h0, jnp.float32),
+         )
+         return (
+             tf.convert_to_tensor(y, tf.bfloat16),
+             tf.convert_to_tensor(s, tf.float32),
+             tf.convert_to_tensor(sa, tf.float32),
+         )
+
+     @tf.py_function(Tout=[tf.bfloat16] * 6 + [tf.float32])
+     def _tf_wkv7_bwd(w, q, k, v, a, b, dy, s, sa, dht):
+         """tf.py_function wrapper around the JAX backward pass"""
+         dw, dq, dk, dv, da, db, dh0 = _wkv7_bwd_kernel(
+             jnp.asarray(w, jnp.bfloat16),
+             jnp.asarray(q, jnp.bfloat16),
+             jnp.asarray(k, jnp.bfloat16),
+             jnp.asarray(v, jnp.bfloat16),
+             jnp.asarray(a, jnp.bfloat16),
+             jnp.asarray(b, jnp.bfloat16),
+             jnp.asarray(dy, jnp.bfloat16),
+             jnp.asarray(s, jnp.float32),
+             jnp.asarray(sa, jnp.float32),
+             jnp.asarray(dht, jnp.bfloat16),
+         )
+         return tuple(
+             tf.convert_to_tensor(g, dtype)
+             for g, dtype in zip((dw, dq, dk, dv, da, db), [tf.bfloat16] * 6)
+         ) + (tf.convert_to_tensor(dh0, tf.float32),)
+
+     # ---------- Forward pass with gradient ----------
+     @tf.custom_gradient
+     def _wk7_tf(w, q, k, v, a, b, h0):
+         y, s, sa = _tf_wkv7_fwd(w, q, k, v, a, b, h0)
+
+         def grad(dy, dht):
+             # dy: upstream gradient of the loss w.r.t. y
+             # dht: gradient w.r.t. the final state (zeros if absent)
+             if dht is None:
+                 dht = tf.zeros_like(h0)
+             grads = _tf_wkv7_bwd(w, q, k, v, a, b, dy, s, sa, dht)
+             return grads  # (dw, dq, dk, dv, da, db, dh0)
+
+         final_state = s[:, :, -1]  # (B, H, K, K)
+         final_state = tf.transpose(final_state, [0, 1, 3, 2])  # matches the JAX layout
+         return (y, final_state), grad
+
+     # ---------- User-facing interface ----------
+     def generalized_delta_rule(
+         r: tf.Tensor,  # (B, T, H, K) or (B, H, T, K)
+         w: tf.Tensor,
+         k: tf.Tensor,
+         v: tf.Tensor,
+         a: tf.Tensor,
+         b: tf.Tensor,
+         initial_state: Optional[tf.Tensor] = None,
+         output_final_state: bool = True,
+         head_first: bool = False,
+         chunk_len: int = 16,
+     ) -> Tuple[tf.Tensor, tf.Tensor]:
+         """
+         1:1 interface match with the JAX version; returns (out, last_state).
+         Compilable with @tf.function, trainable with tf.GradientTape.
+         """
+         dtype = r.dtype
+
+         r = transpose_head(r, head_first)
+         w = transpose_head(w, head_first)
+         k = transpose_head(k, head_first)
+         v = transpose_head(v, head_first)
+         a = transpose_head(a, head_first)
+         b = transpose_head(b, head_first)
+
+         B, T, H, K = tf.unstack(tf.shape(r), num=4)
+         if T % chunk_len != 0:
+             raise ValueError(f"T={T} must be divisible by chunk_len={chunk_len}")
+
+         if initial_state is None:
+             h0 = tf.zeros([B, H, K, K], dtype=tf.float32)
+         else:
+             h0 = tf.cast(initial_state, tf.float32)
+
+         # forward pass with gradient support
+         out, last_state = _wk7_tf(w, r, k, v, a, b, h0)
+
+         # cast back to the caller's dtype
+         out = tf.cast(out, dtype)
+
+         return (out, last_state) if output_final_state else out
+
+     return generalized_delta_rule, _tf_wkv7_fwd, _tf_wkv7_bwd
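
A hedged usage sketch for the wrapper above; the import path is an assumption taken from this diff, and it presumes the underlying JAX CUDA kernel compiles on first use.

# Hedged usage sketch; import path and shapes are assumptions from this diff.
import tensorflow as tf
from rwkv_ops.rwkv7_kernel.tf_eager_kernel import get_tf_generalized_delta_rule

generalized_delta_rule, _, _ = get_tf_generalized_delta_rule(HEAD_SIZE=64)

B, T, H, K = 1, 32, 2, 64  # T must be a multiple of chunk_len (16)
r, w, k, v, a, b = (tf.random.normal((B, T, H, K)) for _ in range(6))

with tf.GradientTape() as tape:
    tape.watch([r, w, k, v, a, b])
    out, last_state = generalized_delta_rule(r, w, k, v, a, b)
    loss = tf.reduce_sum(tf.cast(out, tf.float32))

grads = tape.gradient(loss, [r, w, k, v, a, b])
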
@@ -0,0 +1,165 @@
+ #include <cuda_bf16.h>
+ #include <assert.h>
+
+ using bf = __nv_bfloat16;
+
+ __device__ inline float to_float(const bf & u) {
+     return __bfloat162float(u);
+ }
+
+ __device__ inline bf to_bf(const float & u) {
+     return __float2bfloat16_rn(u);
+ }
+ typedef bf * __restrict__ F_;
+
+ __global__ void forward_kernel(int T, int H,
+         F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
+         bf* y_, float* s_, float* sa_, float* h0_) {
+     constexpr int C = _C_;
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+     float state[C] = {0};
+     __shared__ float q[C], k[C], w[C], a[C], b[C];
+     int h0_base = ((bb*H + hh)*C + i)*C;
+ #pragma unroll
+     for (int j = 0; j < C; j++) {
+         state[j] = h0_[h0_base + j];
+     }
+     for (int t = 0; t < T; t++) {
+         int ind = bb*T*H*C + t*H*C + hh*C + i;
+         __syncthreads();
+         q[i] = to_float(q_[ind]);
+         w[i] = __expf(-__expf(to_float(w_[ind])));
+         k[i] = to_float(k_[ind]);
+         a[i] = to_float(a_[ind]);
+         b[i] = to_float(b_[ind]);
+         __syncthreads();
+         float sa = 0;
+ #pragma unroll
+         for (int j = 0; j < C; j++) {
+             sa += a[j] * state[j];
+         }
+         sa_[ind] = sa;
+         float v = to_float(v_[ind]);
+         float y = 0;
+ #pragma unroll
+         for (int j = 0; j < C; j++) {
+             float& s = state[j];
+             s = s * w[j] + sa * b[j] + k[j] * v;
+             y += s * q[j];
+         }
+         y_[ind] = to_bf(y);
+
+         if ((t+1) % _CHUNK_LEN_ == 0) {
+             int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i;
+ #pragma unroll
+             for (int j = 0; j < C; j++) {
+                 s_[base + j*C] = state[j];
+             }
+         }
+     }
+ }
+ __global__ void backward_kernel(int T, int H,
+         F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_,
+         float * __restrict__ s_, float * __restrict__ sa_,
+         float * __restrict__ dht_, float * __restrict__ dh0_,
+         bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
+     constexpr int C = _C_;
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+     float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
+     int dht_base = ((bb*H + hh)*C + i)*C;
+ #pragma unroll
+     for (int j = 0; j < C; j++) {
+         dstate[j] = dht_[dht_base + j];
+         dstateT[j] = dht_[dht_base + j];
+     }
+     __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
+     float qi, wi, ki, ai, bi, dyi;
+
+     for (int t = T-1; t >= 0; t--) {
+         int ind = bb*T*H*C + t*H*C + hh*C + i;
+         __syncthreads();
+         q[i] = qi = to_float(q_[ind]);
+         float wi_fac = -__expf(to_float(w_[ind]));
+         w[i] = wi = __expf(wi_fac);
+         k[i] = ki = to_float(k_[ind]);
+         a[i] = ai = to_float(a_[ind]);
+         b[i] = bi = to_float(b_[ind]);
+         v[i] = to_float(v_[ind]);
+         dy[i] = dyi = to_float(dy_[ind]);
+         sa[i] = sa_[ind];
+         __syncthreads();
+
+         if ((t+1) % _CHUNK_LEN_ == 0) {
+             int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C;
+ #pragma unroll
+             for (int j = 0; j < C; j++) {
+                 stateT[j] = s_[base + j];
+             }
+         }
+         float dq = 0;
+ #pragma unroll
+         for (int j = 0; j < C; j++) {
+             dq += stateT[j]*dy[j];
+         }
+         dq_[ind] = to_bf(dq);
+         float iwi = 1.0f/(wi+0.000001f);
+ #pragma unroll
+         for (int j = 0; j < C; j++) {
+             stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
+             dstate[j] += dyi * q[j];
+             dstateT[j] += qi * dy[j];
+         }
+         float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
+ #pragma unroll
+         for (int j = 0; j < C; j++) {
+             dw += dstateT[j]*stateT[j];
+             dk += dstateT[j]*v[j];
+             dv += dstate[j]*k[j];
+             dSb += dstate[j]*b[j];
+             db += dstateT[j]*sa[j];
+         }
+         dw_[ind] = to_bf(dw * wi * wi_fac);
+         dk_[ind] = to_bf(dk);
+         dv_[ind] = to_bf(dv);
+         db_[ind] = to_bf(db);
+         __syncthreads();
+         dSb_shared[i] = dSb;
+         __syncthreads();
+         float da = 0;
+ #pragma unroll
+         for (int j = 0; j < C; j++) {
+             da += stateT[j]*dSb_shared[j];
+         }
+         da_[ind] = to_bf(da);
+ #pragma unroll
+         for (int j = 0; j < C; j++) {
+             dstate[j] = dstate[j]*w[j] + dSb * a[j];
+             dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j];
+             if (t == 0) {
+                 dh0_[dht_base + j] = dstate[j];
+             }
+         }
+     }
+ }
+
+ void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa, float* h0) {
+     forward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,y,s,sa,h0);
+ }
+
+ void cuda_backward(int B, int T, int H,
+         bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy,
+         float*s, float*sa, float*dht, float*dh0,
+         bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da
+ ) {
+     assert(T % _CHUNK_LEN_ == 0);
+     backward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,dy,s,sa,dht,dh0,dw,dq,dk,dv,dz,da);
+ }
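
As a cross-check on the state recurrence, a plain-NumPy sketch of what one (batch, head) slice of forward_kernel computes; CUDA thread i corresponds to row i of the state matrix. Shapes are assumptions and the chunked s/sa bookkeeping is omitted.

# NumPy reference sketch for one (batch, head) slice of forward_kernel (assumed shapes).
import numpy as np

C, T = 8, 32
w_in, q, k, v, a, b = (np.random.randn(T, C).astype(np.float32) for _ in range(6))
state = np.zeros((C, C), dtype=np.float32)  # row i == CUDA thread i
y = np.zeros((T, C), dtype=np.float32)

for t in range(T):
    w = np.exp(-np.exp(w_in[t]))  # decay in (0, 1), as computed in the kernel
    sa = state @ a[t]             # sa[i] = sum_j a[j] * state[i, j]
    state = state * w[None, :] + np.outer(sa, b[t]) + np.outer(v[t], k[t])
    y[t] = state @ q[t]           # y[i] = sum_j state[i, j] * q[j]
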
@@ -0,0 +1,35 @@
+ #include <torch/extension.h>
+ #include <cuda_bf16.h>
+
+ using bf = __nv_bfloat16;
+ void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa, float*h0);
+
+ void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v,
+         torch::Tensor &z, torch::Tensor &a,
+         torch::Tensor &y,
+         torch::Tensor &s, torch::Tensor &sa, torch::Tensor &h0) {
+     int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
+     cuda_forward(B, T, H,
+         (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(),
+         (float*)s.data_ptr(), (float*)sa.data_ptr(), (float*)h0.data_ptr());
+ }
+ void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, float*dht, float*dh0, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da);
+ void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
+         torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dht, torch::Tensor &dh0,
+         torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) {
+     int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
+     cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(),
+         (bf*)dy.data_ptr(),
+         (float*)s.data_ptr(), (float*)sa.data_ptr(), (float*)dht.data_ptr(), (float*)dh0.data_ptr(),
+         (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr());
+ }
+
+ TORCH_LIBRARY(wind_backstepping, m) {
+     m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa, Tensor(f!) h0) -> ()");
+     m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa, Tensor dht, Tensor(a!) dh0, Tensor(b!) dw, Tensor(c!) dq, Tensor(d!) dk, Tensor(e!) dv, Tensor(f!) dz, Tensor(g!) da) -> ()");
+ }
+
+ TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) {
+     m.impl("forward", &forward);
+     m.impl("backward", &backward);
+ }
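
Once the extension is built and loaded (for example via torch.utils.cpp_extension.load; the build step is not part of this diff), the registered ops are reachable through the dispatcher. A hedged sketch of the forward call, with output buffers shaped to match the kernel above:

# Hedged sketch: calling the wind_backstepping.forward op registered above.
# Assumes the extension was compiled with _C_=64 and _CHUNK_LEN_=16 and is already loaded.
import torch

B, T, H, C, CHUNK_LEN = 1, 32, 2, 64, 16
dev = "cuda"

w, q, k, v, z, a = (torch.randn(B, T, H, C, dtype=torch.bfloat16, device=dev) for _ in range(6))
y = torch.empty(B, T, H, C, dtype=torch.bfloat16, device=dev)
s = torch.empty(B, H, T // CHUNK_LEN, C, C, dtype=torch.float32, device=dev)
sa = torch.empty(B, T, H, C, dtype=torch.float32, device=dev)
h0 = torch.zeros(B, H, C, C, dtype=torch.float32, device=dev)

torch.ops.wind_backstepping.forward(w, q, k, v, z, a, y, s, sa, h0)
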