rwkv_ops-0.6.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0

rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py
@@ -0,0 +1,190 @@
"""
JAX RWKV7 single-step wkv kernel (forward pass only).
The CUDA extension is compiled lazily; the kernel is specialized for the T=1 case.
"""

from __future__ import annotations
import pathlib
import subprocess
import ctypes
import jax
import jax.numpy as jnp
from typing import Optional, Tuple, Union

# ---------- Lazy compilation (build output goes to the current directory) ----------
_CURRENT_DIR = pathlib.Path(__file__).parent.absolute()


def get_jax_generalized_delta_rule_single_step(HEAD_SIZE=64):
    _BUILD_DIR = _CURRENT_DIR / f"build_single_step_{HEAD_SIZE}"
    _SO_PATH = _CURRENT_DIR / f"build_single_step_{HEAD_SIZE}/wkv7_single_step.so"

    def _ensure_compiled() -> pathlib.Path:
        """Compile the CUDA extension on first use; artifacts go into the current source directory."""
        if _SO_PATH.exists():
            return _SO_PATH

        print("[rwkv7_single_step_jax] First use – compiling CUDA kernel…")
        src_dir = _CURRENT_DIR
        build_dir = _BUILD_DIR
        build_dir.mkdir(exist_ok=True)

        # Locate the XLA header directory
        xla_include_dir = jax.ffi.include_dir()
        if not xla_include_dir:
            raise RuntimeError(
                "jax.ffi.include_dir() returned an empty path; please check that JAX >= 0.4.31 is installed"
            )

        # CUDA compile flags (the CHUNK_LEN define is removed here)
        cuda_flags = [
            "-ftz=true",
            "-prec-div=false",
            "-prec-sqrt=false",
            "--use_fast_math",
            "-O3",
            "-Xptxas=-O3",
            "-res-usage",
            "--extra-device-vectorization",
            f"-D_C_={HEAD_SIZE}",
        ]

        # CMake configuration
        cmake_args = [
            "cmake",
            "-S",
            str(src_dir),
            "-B",
            str(build_dir),
            "-DCMAKE_BUILD_TYPE=Release",
            f"-DCMAKE_INSTALL_PREFIX={_CURRENT_DIR}",
            f"-DXLA_INCLUDE_DIR={xla_include_dir}",
            f"-DCMAKE_CUDA_FLAGS={' '.join(cuda_flags)}",
        ]
        subprocess.check_call(cmake_args)

        # Build and install
        subprocess.check_call(["cmake", "--build", str(build_dir), "-j"])
        subprocess.check_call(["cmake", "--install", str(build_dir)])

        if not _SO_PATH.exists():
            raise RuntimeError("Compilation failed – wkv7_single_step.so not found.")

        print("[rwkv7_single_step_jax] Compilation finished – output at", _SO_PATH)
        return _SO_PATH

    # Register the FFI symbol (forward only)
    _lib = ctypes.CDLL(_ensure_compiled())
    jax.ffi.register_ffi_target(
        "wkv7_single_step_fwd",
        jax.ffi.pycapsule(_lib.Wkv7SingleStepFwd),
        platform="CUDA",
    )

    # ---------- Helpers ----------
    def _transpose_head(x: jnp.ndarray, head_first: bool) -> jnp.ndarray:
        """(B, 1, H, K) <-> (B, H, 1, K)"""
        x = jnp.asarray(x, dtype=jnp.bfloat16)
        if head_first:
            return jnp.transpose(x, (0, 2, 1, 3))
        return x

    # ---------- Forward kernel ----------
    def _wkv7_single_step_kernel(
        w: jnp.ndarray,
        q: jnp.ndarray,
        k: jnp.ndarray,
        v: jnp.ndarray,
        a: jnp.ndarray,
        b: jnp.ndarray,
        h0: jnp.ndarray,
    ):
        """
        Internal kernel interface.
        Arguments: w, q, k, v, a, b, h0 -> y, s
        """
        B, H, K = q.shape
        dtype = q.dtype

        out_type = jax.ShapeDtypeStruct((B, H, K), dtype)
        s_type = jax.ShapeDtypeStruct((B, H, K, K), jnp.float32)

        y, s = jax.ffi.ffi_call(
            "wkv7_single_step_fwd", (out_type, s_type), vmap_method="broadcast_all"
        )(w, q, k, v, a, b, h0)

        return y, s

    def wk7_single_step_kernel(
        w: jnp.ndarray,
        q: jnp.ndarray,
        k: jnp.ndarray,
        v: jnp.ndarray,
        a: jnp.ndarray,
        b: jnp.ndarray,
        h0: jnp.ndarray,
    ):
        """Forward computation wrapper."""
        y, s = _wkv7_single_step_kernel(w, q, k, v, a, b, h0)
        final_state = s  # after a single step the state can be returned directly
        return (y, final_state)

    # ---------- Main interface ----------
    def generalized_delta_rule_single_step(
        r: jnp.ndarray,
        w: jnp.ndarray,
        k: jnp.ndarray,
        v: jnp.ndarray,
        a: jnp.ndarray,
        b: jnp.ndarray,
        initial_state: Optional[jnp.ndarray] = None,
        output_final_state: bool = True,
        head_first: bool = False,
    ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
        """
        Single-step generalized delta rule (forward only).

        Args:
            r, w, k, v, a, b: input tensors of shape (B, 1, H, K) or (B, H, 1, K)
            initial_state: optional (B, H, K, K) initial state; zero-initialized when None
            output_final_state: whether to also return the final state
            head_first: whether the head dimension precedes the time dimension
        Returns:
            out: (B, 1, H, K), same dtype as the inputs
            last_state: (B, H, K, K) when output_final_state=True
        """
        # Normalize everything to (B, 1, H, K) and validate T == 1
        r = _transpose_head(r, head_first)
        w = _transpose_head(w, head_first)
        k = _transpose_head(k, head_first)
        v = _transpose_head(v, head_first)
        a = _transpose_head(a, head_first)
        b = _transpose_head(b, head_first)

        B, T, H, K = r.shape
        if T != 1:
            raise ValueError(f"Single-step kernel requires T=1, but got T={T}.")

        # Handle the initial state
        if initial_state is None:
            h0 = jnp.zeros((B, H, K, K), jnp.float32)
        else:
            h0 = jnp.asarray(initial_state, jnp.float32)

        # Drop the T dimension before calling the kernel
        r = r[:, 0, :, :]  # (B, H, K)
        w = w[:, 0, :, :]
        k = k[:, 0, :, :]
        v = v[:, 0, :, :]
        a = a[:, 0, :, :]
        b = b[:, 0, :, :]

        # Call the forward kernel
        out, last_state = wk7_single_step_kernel(w, r, k, v, a, b, h0)

        # Restore the T dimension
        out = jnp.expand_dims(out, axis=1)  # (B, 1, H, K)
        out = jnp.asarray(out, r.dtype)

        if output_final_state:
            return out, last_state
        return out

    return generalized_delta_rule_single_step
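
A minimal usage sketch for the factory above (my own example, not shipped in the wheel). It assumes a CUDA GPU plus cmake and nvcc on PATH, since the first call builds the extension, and it imports the module via its path inside the wheel; any top-level re-export in rwkv_ops may expose it differently.

import jax
import jax.numpy as jnp

from rwkv_ops.rwkv7_kernel.jax_cuda_kernel_single.wkv7_single_step_jax import (
    get_jax_generalized_delta_rule_single_step,
)

B, H, K = 2, 8, 64  # K must match HEAD_SIZE: the extension is built with -D_C_=HEAD_SIZE
step = get_jax_generalized_delta_rule_single_step(HEAD_SIZE=K)  # compiles and registers the FFI target

keys = jax.random.split(jax.random.PRNGKey(0), 6)
# Inputs are (B, 1, H, K); the wrapper casts them to bfloat16 internally.
r, w, k, v, a, b = (jax.random.normal(s, (B, 1, H, K), jnp.bfloat16) for s in keys)

out, state = step(r, w, k, v, a, b, initial_state=None, output_final_state=True)
print(out.shape, out.dtype)      # (2, 1, 8, 64) bfloat16
print(state.shape, state.dtype)  # (2, 8, 64, 64) float32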

rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py
@@ -0,0 +1,9 @@
from ..jax_kernel.chunk_A_fwd import *
from ..jax_kernel.chunk_A_bwd import *
from ..jax_kernel.chunk_h_fwd import *
from ..jax_kernel.chunk_h_bwd import *
from ..jax_kernel.chunk_o_fwd import *
from ..jax_kernel.chunk_o_bwd import *
from ..jax_kernel.cumsum import *
from ..jax_kernel.wy_fast_fwd import *
from ..jax_kernel.wy_fast_bwd import *

rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py
@@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025,Qingwen Lin

import jax_triton as jt
import jax
import triton
from ..triton_kernel.chunk_A_bwd import *
from ..triton_kernel.utils import is_gather_supported
from ..get_torch_devices_info import check_shared_mem


def chunk_dplr_bwd_dqk_intra(
    q: jax.Array,
    k: jax.Array,
    a: jax.Array,
    b: jax.Array,
    gi: jax.Array,
    ge: jax.Array,
    dAqk: jax.Array,
    dAqb: jax.Array,
    dAak: jax.Array,
    dAab: jax.Array,
    dqg: jax.Array,
    dkg: jax.Array,
    dag: jax.Array,
    dbg: jax.Array,
    dgk_last: jax.Array,
    scale: float = 1.0,
    chunk_size: int = 16,
):
    B, T, H, K = q.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BK = (
        min(64, triton.next_power_of_2(K))
        if check_shared_mem()
        else min(32, triton.next_power_of_2(K))
    )

    NT = triton.cdiv(T, BT)
    NK = triton.cdiv(K, BK)
    grid = (NK, NT, B * H)

    out_shapes = [
        jax.ShapeDtypeStruct(q.shape, q.dtype),
        jax.ShapeDtypeStruct(k.shape, k.dtype),
        jax.ShapeDtypeStruct(a.shape, a.dtype),
        jax.ShapeDtypeStruct(b.shape, b.dtype),
        jax.ShapeDtypeStruct(gi.shape, "float32"),
        jax.ShapeDtypeStruct(gi.shape, "float32"),
    ]

    dq, dk, da, db, dgk, dgk_offset = jt.triton_call(
        q,
        k,
        a,
        b,
        gi,
        ge,
        dAqk,
        dAqb,
        dAak,
        dAab,
        dqg,
        dkg,
        dag,
        dbg,
        T,
        scale=scale,
        H=H,
        K=K,
        BT=BT,
        BC=BT,
        BK=BK,
        GATHER_SUPPORTED=is_gather_supported,
        kernel=chunk_dplr_bwd_kernel_intra,
        out_shape=out_shapes,
        grid=grid,
    )

    def grid(meta):
        return (NT, triton.cdiv(K, meta["BK"]), B * H)

    dgk_output = jt.triton_call(
        dgk,
        dgk_offset,
        dgk_last,
        T,
        H=H,
        K=K,
        BT=BT,
        kernel=chunk_dplr_bwd_dgk_kernel,
        out_shape=[jax.ShapeDtypeStruct(dgk.shape, "float32")],
        grid=grid,
    )[0]
    return dq, dk, da, db, dgk_output

rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025,Qingwen Lin


import jax_triton as jt
import jax
import triton

from ..triton_kernel.utils import is_gather_supported

from ..triton_kernel.chunk_A_fwd import *


def chunk_dplr_fwd_intra(
    q: jax.Array,
    k: jax.Array,
    a: jax.Array,
    b: jax.Array,
    gi: jax.Array,
    ge: jax.Array,
    scale: float,
    chunk_size: int,
):
    B, T, H, K = k.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    NT = triton.cdiv(T, BT)
    shape = [B, T, H, BT]
    out_shapes = [
        jax.ShapeDtypeStruct(q.shape, q.dtype),
        jax.ShapeDtypeStruct(k.shape, q.dtype),
        jax.ShapeDtypeStruct(a.shape, q.dtype),
        jax.ShapeDtypeStruct(b.shape, q.dtype),
        jax.ShapeDtypeStruct(shape, q.dtype),
        jax.ShapeDtypeStruct(shape, q.dtype),
        jax.ShapeDtypeStruct(shape, "float32"),
        jax.ShapeDtypeStruct(shape, "float32"),
    ]
    grid = (NT, B, H)
    BK = triton.next_power_of_2(K)
    qg, kg, ag, bg, Aqk, Aqb, Aab, Aak = jt.triton_call(
        q,
        k,
        a,
        b,
        gi,
        ge,
        T,
        scale=scale,
        H=H,
        K=K,
        BT=BT,
        BC=BT,
        BK=BK,
        GATHER_SUPPORTED=is_gather_supported,
        kernel=chunk_dplr_fwd_A_kernel_intra_sub_intra,
        out_shape=out_shapes,
        grid=grid,
    )
    return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
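
A hypothetical shape-level sketch of calling chunk_dplr_fwd_intra directly, assuming a GPU with a working Triton/jax_triton install; the random gi/ge tensors merely stand in for the gate terms produced elsewhere in the pipeline.

import jax
import jax.numpy as jnp

from rwkv_ops.rwkv7_kernel.jax_kernel.chunk_A_fwd import chunk_dplr_fwd_intra

B, T, H, K = 1, 128, 4, 64
ks = jax.random.split(jax.random.PRNGKey(0), 6)
q, k, a, b = (jax.random.normal(x, (B, T, H, K), jnp.bfloat16) for x in ks[:4])
gi, ge = (jax.random.normal(x, (B, T, H, K), jnp.float32) for x in ks[4:])

Aab, Aqk, Aak, Aqb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
    q, k, a, b, gi, ge, scale=K**-0.5, chunk_size=16
)
# The A* tensors are per-chunk (B, T, H, BT) blocks; qg/kg/ag/bg keep the input shapes.
print(Aqk.shape, qg.shape)  # (1, 128, 4, 16) (1, 128, 4, 64)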

rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py
@@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025,Qingwen Lin

from typing import Optional, Tuple

import jax_triton as jt
import jax
import triton

from ..get_jax_devices_info import check_shared_mem
from ..triton_kernel.chunk_h_bwd import *


def chunk_dplr_bwd_dhu(
    qg: jax.Array,
    bg: jax.Array,
    w: jax.Array,
    gk: jax.Array,
    h0: jax.Array,
    dht: Optional[jax.Array],
    do: jax.Array,
    dv: jax.Array,
    chunk_size: int = 64,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    B, T, H, K, V = *qg.shape, do.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    BK = triton.next_power_of_2(K)
    assert BK <= 256, (
        "current kernel does not support head dimension being larger than 256."
    )
    # H100
    if check_shared_mem("hopper"):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem("ampere"):  # A100
        BV = 32
        BC = 32
    else:  # e.g. 4090
        BV = 16
        BC = 16

    N, NT = B, triton.cdiv(T, BT)
    BC = min(BT, BC)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, (
        "NK > 1 is not supported because it involves time-consuming synchronization"
    )
    dh_shape = (B, NT, H, K, V)
    out_shapes = [
        jax.ShapeDtypeStruct(dh_shape, dv.dtype),
        jax.ShapeDtypeStruct((B, H, K, V), "float32"),
        jax.ShapeDtypeStruct(dv.shape, dv.dtype),
    ]

    grid = (NK, NV, N * H)
    dh, dh0, dv2 = jt.triton_call(
        qg,
        bg,
        w,
        gk,
        dht,
        dv,
        do,
        T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
        kernel=chunk_dplr_bwd_kernel_dhu.fn,
        out_shape=out_shapes,
        grid=grid,
        USE_FINAL_STATE_GRADIENT=dht is not None,
        USE_INITIAL_STATE=h0 is not None,
    )
    return dh, dh0, dv2

rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025,Qingwen Lin

from typing import Optional, Tuple

import jax_triton as jt
import jax
import triton

from ..get_jax_devices_info import check_shared_mem
from ..triton_kernel.chunk_h_fwd import *


def chunk_dplr_fwd_h(
    kg: jax.Array,
    v: jax.Array,
    w: jax.Array,
    u: jax.Array,
    bg: jax.Array,
    gk: jax.Array,
    initial_state: Optional[jax.Array] = None,
    output_final_state: bool = False,
    chunk_size: int = 64,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    B, T, H, K, V = *kg.shape, u.shape[-1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))

    N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    BK = triton.next_power_of_2(K)
    assert BK <= 256, "current kernel does not support head dimension larger than 256."
    # H100 can have larger block size

    if check_shared_mem("hopper"):
        BV = 64
        BC = 64 if K <= 128 else 32
    elif check_shared_mem("ampere"):  # A100
        BV = 32
        BC = 32
    else:
        BV = 16
        BC = 16

    BC = min(BT, BC)
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, (
        "NK > 1 is not supported because it involves time-consuming synchronization"
    )

    out_shapes = [
        jax.ShapeDtypeStruct((B, NT, H, K, V), kg.dtype),
        jax.ShapeDtypeStruct([N, H, K, V], "float32"),
        jax.ShapeDtypeStruct(u.shape, u.dtype),
    ]
    grid = (NK, NV, N * H)
    if initial_state is None:
        initial_state = jax.numpy.zeros([N, H, K, V], "float32")
    h, final_state, v_new = jt.triton_call(
        kg,
        v,
        w,
        bg,
        u,
        gk,
        initial_state,
        T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BC=BC,
        BK=BK,
        BV=BV,
        kernel=chunk_dplr_fwd_kernel_h.fn,
        out_shape=out_shapes,
        grid=grid,
        STORE_FINAL_STATE=True,
        USE_INITIAL_STATE=True,
    )
    return h, v_new, final_state
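
Likewise, a hypothetical sketch for the chunk-level state recurrence above. In the real forward pass kg/bg come from chunk_dplr_fwd_intra and w/u/gk from the wy_fast and cumsum steps; the random tensors below only illustrate the expected shapes and dtypes.

import jax
import jax.numpy as jnp

from rwkv_ops.rwkv7_kernel.jax_kernel.chunk_h_fwd import chunk_dplr_fwd_h

B, T, H, K, V = 1, 128, 4, 64, 64
ks = jax.random.split(jax.random.PRNGKey(0), 6)
kg, w, bg = (jax.random.normal(x, (B, T, H, K), jnp.bfloat16) for x in ks[:3])
gk = jax.random.normal(ks[3], (B, T, H, K), jnp.float32)
v, u = (jax.random.normal(x, (B, T, H, V), jnp.bfloat16) for x in ks[4:])

h, v_new, final_state = chunk_dplr_fwd_h(
    kg=kg, v=v, w=w, u=u, bg=bg, gk=gk, initial_state=None, chunk_size=64
)
print(h.shape)            # (1, 2, 4, 64, 64): one (K, V) state block per 64-step chunk
print(v_new.shape)        # (1, 128, 4, 64), same shape as u
print(final_state.shape)  # (1, 4, 64, 64), float32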

rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py
@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Tuple

import jax_triton as jt
import jax
import triton

from ..get_jax_devices_info import check_shared_mem
from ..triton_kernel.chunk_o_bwd import *


def chunk_dplr_bwd_dv(
    A_qk: jax.Array,
    kg: jax.Array,
    do: jax.Array,
    dh: jax.Array,
    chunk_size: int = 64,
) -> jax.Array:
    B, T, H, K, V = *kg.shape, do.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    NT = triton.cdiv(T, BT)

    def grid(meta):
        return (triton.cdiv(V, meta["BV"]), NT, B * H)

    dv = jt.triton_call(
        A_qk,
        kg,
        do,
        dh,
        T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        kernel=chunk_dplr_bwd_kernel_dv,
        out_shape=jax.ShapeDtypeStruct(do.shape, do.dtype),
        grid=grid,
    )
    return dv


def chunk_dplr_bwd_o(
    k: jax.Array,
    b: jax.Array,
    v: jax.Array,
    v_new: jax.Array,
    gk: jax.Array,
    do: jax.Array,
    h: jax.Array,
    dh: jax.Array,
    dv: jax.Array,
    w: jax.Array,
    chunk_size: int = 64,
    scale: float = 1.0,
) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    B, T, H, K, V = *w.shape, v.shape[-1]

    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T, BT)

    BK = (
        min(triton.next_power_of_2(K), 64)
        if check_shared_mem()
        else min(triton.next_power_of_2(K), 32)
    )
    BV = (
        min(triton.next_power_of_2(V), 64)
        if check_shared_mem()
        else min(triton.next_power_of_2(K), 32)
    )
    NK = triton.cdiv(K, BK)
    grid = (NK, NT, B * H)

    out_shapes = [
        jax.ShapeDtypeStruct(k.shape, k.dtype),
        jax.ShapeDtypeStruct(k.shape, k.dtype),
        jax.ShapeDtypeStruct(w.shape, w.dtype),
        jax.ShapeDtypeStruct(b.shape, b.dtype),
        jax.ShapeDtypeStruct([B, NT, H, K], "float32"),
    ]
    dq, dk, dw, db, dgk_last = jt.triton_call(
        v,
        v_new,
        h,
        do,
        dh,
        w,
        dv,
        gk,
        k,
        b,
        T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        kernel=chunk_dplr_bwd_o_kernel,
        out_shape=out_shapes,
        grid=grid,
    )
    return (dq, dk, dw, db, dgk_last)


def chunk_dplr_bwd_dAu(
    v: jax.Array,
    v_new: jax.Array,
    do: jax.Array,
    A_qb: jax.Array,
    scale: float,
    chunk_size: int = 64,
) -> Tuple[jax.Array, jax.Array, jax.Array]:
    B, T, H, V = v.shape
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    NT = triton.cdiv(T, BT)

    if check_shared_mem("ampere"):  # A100
        BV = min(triton.next_power_of_2(V), 128)
    elif check_shared_mem("ada"):  # 4090
        BV = min(triton.next_power_of_2(V), 64)
    else:
        BV = min(triton.next_power_of_2(V), 32)

    grid = (NT, B * H)
    out_shapes = [
        jax.ShapeDtypeStruct([B, T, H, BT], "float32"),
        jax.ShapeDtypeStruct([B, T, H, BT], "float32"),
        jax.ShapeDtypeStruct(v_new.shape, v_new.dtype),
    ]
    dA_qk, dA_qb, dv_new = jt.triton_call(
        v,
        do,
        v_new,
        A_qb,
        T,
        scale=scale,
        H=H,
        V=V,
        BT=BT,
        BV=BV,
        grid=grid,
        out_shape=out_shapes,
        kernel=chunk_dplr_bwd_kernel_dAu,
    )
    return dv_new, dA_qk, dA_qb
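
A shape-only sketch (hypothetical, not from the package) of the chunk_dplr_bwd_dv entry point above; in practice A_qk, kg, do and dh would be saved forward activations and incoming gradients rather than random tensors.

import jax
import jax.numpy as jnp

from rwkv_ops.rwkv7_kernel.jax_kernel.chunk_o_bwd import chunk_dplr_bwd_dv

B, T, H, K, V, BT = 1, 128, 4, 64, 64, 64
NT = T // BT
ks = jax.random.split(jax.random.PRNGKey(0), 4)
A_qk = jax.random.normal(ks[0], (B, T, H, BT), jnp.bfloat16)   # intra-chunk A from the forward pass
kg = jax.random.normal(ks[1], (B, T, H, K), jnp.bfloat16)      # gate-adjusted keys
do = jax.random.normal(ks[2], (B, T, H, V), jnp.bfloat16)      # gradient w.r.t. the output o
dh = jax.random.normal(ks[3], (B, NT, H, K, V), jnp.bfloat16)  # gradient w.r.t. the chunk states

dv = chunk_dplr_bwd_dv(A_qk, kg, do, dh, chunk_size=BT)
print(dv.shape)  # (1, 128, 4, 64), same shape and dtype as do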

rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025,Qingwen Lin


import jax_triton as jt
import jax
import triton

from ..triton_kernel.chunk_o_fwd import *


def chunk_dplr_fwd_o(
    qg: jax.Array,
    v: jax.Array,
    v_new: jax.Array,
    A_qk: jax.Array,
    A_qb: jax.Array,
    h: jax.Array,
    chunk_size: int = 64,
) -> jax.Array:
    B, T, H, K, V = *qg.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))

    NT = triton.cdiv(T, BT)

    def grid(meta):
        return (triton.cdiv(V, meta["BV"]), NT, B * H)

    o = jt.triton_call(
        qg,
        v,
        v_new,
        A_qk,
        A_qb,
        h,
        T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        kernel=chunk_dplr_fwd_kernel_o,
        out_shape=jax.ShapeDtypeStruct(v.shape, v.dtype),
        grid=grid,
    )
    return o
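
Finally, a hypothetical sketch of the output step: chunk_dplr_fwd_o combines qg and the A_qk/A_qb blocks from chunk_dplr_fwd_intra with h and v_new from chunk_dplr_fwd_h, and returns o with the same shape and dtype as v. The random tensors only stand in for those intermediate results.

import jax
import jax.numpy as jnp

from rwkv_ops.rwkv7_kernel.jax_kernel.chunk_o_fwd import chunk_dplr_fwd_o

B, T, H, K, V, BT = 1, 128, 4, 64, 64, 16
NT = T // BT
ks = jax.random.split(jax.random.PRNGKey(0), 6)
qg = jax.random.normal(ks[0], (B, T, H, K), jnp.bfloat16)
v, v_new = (jax.random.normal(x, (B, T, H, V), jnp.bfloat16) for x in ks[1:3])
A_qk, A_qb = (jax.random.normal(x, (B, T, H, BT), jnp.bfloat16) for x in ks[3:5])
h = jax.random.normal(ks[5], (B, NT, H, K, V), jnp.bfloat16)

o = chunk_dplr_fwd_o(qg, v, v_new, A_qk, A_qb, h, chunk_size=BT)
print(o.shape, o.dtype)  # (1, 128, 4, 64) bfloat16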