rwkv-ops 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,939 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JAX FFI 版 MHC 算子库
|
|
3
|
+
- Sinkhorn Knopp: 实现双拟随机矩阵投影
|
|
4
|
+
- 接口与 native_keras_op.py 完全一致
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
import pathlib
|
|
9
|
+
import subprocess
|
|
10
|
+
import ctypes
|
|
11
|
+
import numpy as np # <--- 添加numpy导入
|
|
12
|
+
from typing import Tuple
|
|
13
|
+
import jax
|
|
14
|
+
import jax.numpy as jnp
|
|
15
|
+
from jax.ad_checkpoint import checkpoint_policies as cp
|
|
16
|
+
|
|
17
|
+
# 当前目录
|
|
18
|
+
_CURRENT_DIR = pathlib.Path(__file__).parent.absolute()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# ---------- Lazy compilation ----------
def _ensure_compiled() -> pathlib.Path:
    """Return the path to ``mhu.so``, building the CUDA extension if needed.

    When ``mhu.so`` already exists next to this module it is returned
    immediately. Otherwise the CMake project in this directory is configured,
    built and installed using ``build/`` as the build tree. The build is
    expected to leave ``mhu.so`` next to this module, and that is checked
    afterwards.

    Returns:
        Path of the shared library.

    Raises:
        RuntimeError: if ``jax.ffi.include_dir()`` returns an empty value, a
            CMake step exits with a non-zero status (the original
            ``CalledProcessError`` is chained as ``__cause__``), or ``mhu.so``
            is still missing after the build.
    """
    so_path = _CURRENT_DIR / "mhu.so"
    if so_path.exists():
        return so_path

    print("[mhu_jax] 首次使用 - 正在编译CUDA内核...")

    build_dir = _CURRENT_DIR / "build"
    build_dir.mkdir(exist_ok=True)

    # XLA FFI headers shipped with jax; passed to CMake as XLA_INCLUDE_DIR.
    xla_include_dir = jax.ffi.include_dir()
    if not xla_include_dir:
        raise RuntimeError("jax.ffi.include_dir() 返回空,请检查JAX版本>=0.4.31")

    cmake_args = [
        "cmake",
        "-S",
        str(_CURRENT_DIR),
        "-B",
        str(build_dir),
        "-DCMAKE_BUILD_TYPE=Release",
        f"-DXLA_INCLUDE_DIR={xla_include_dir}",
        "-DCMAKE_CUDA_FLAGS=-O3 --use_fast_math -std=c++17",
    ]

    try:
        subprocess.check_call(cmake_args, cwd=build_dir)
        subprocess.check_call(["cmake", "--build", str(build_dir), "-j"], cwd=build_dir)
        subprocess.check_call(["cmake", "--install", str(build_dir)], cwd=build_dir)
    except subprocess.CalledProcessError as e:
        # Chain the original error so the failing command and its exit code
        # stay visible in the traceback.
        raise RuntimeError(f"CMake编译失败: {e}") from e

    if not so_path.exists():
        files = list(_CURRENT_DIR.glob("*"))
        raise RuntimeError(
            f"编译失败 - 无法在 {so_path} 找到共享库\n"
            f"当前目录内容: {[f.name for f in files]}"
        )

    print(f"[mhu_jax] 编译完成 - 输出: {so_path}")
    return so_path
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# ---------- FFI target registration ----------
# Load the compiled library (building it on first import), then register each
# exported CUDA handler with XLA under the target name used by the ffi_call
# sites in this module.
_LIB = ctypes.CDLL(_ensure_compiled())
for _target_name, _symbol_name in (
    ("sinkhorn_fwd", "SinkhornFwd"),
    ("sinkhorn_bwd", "SinkhornBwd"),
    ("rmsnorm_fwd", "RMSNormFwd"),
    ("rmsnorm_bwd", "RMSNormBwd"),
    ("stream_mix_fwd", "StreamMixFwd"),
    ("stream_mix_bwd", "StreamMixBwd"),
    ("stream_aggregate_fwd", "StreamAggregateFwd"),
    ("stream_aggregate_bwd", "StreamAggregateBwd"),
):
    jax.ffi.register_ffi_target(
        _target_name,
        jax.ffi.pycapsule(getattr(_LIB, _symbol_name)),
        platform="CUDA",
    )
del _target_name, _symbol_name
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _normalize_shape(x: jnp.ndarray, expected_ndim: int, name: str) -> jnp.ndarray:
|
|
100
|
+
"""确保数组维度正确"""
|
|
101
|
+
if x.ndim != expected_ndim:
|
|
102
|
+
raise ValueError(f"{name}期望{expected_ndim}维张量,但输入为{x.ndim}维")
|
|
103
|
+
return x
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# ---------- Core implementation ----------
def _sinkhorn_ffi_fwd(
    inp: jnp.ndarray,
    num_iters: np.int32,
    eps: np.float32,
) -> jnp.ndarray:
    """Invoke the ``sinkhorn_fwd`` FFI target.

    Args:
        inp: input matrices; converted to float32 before the call.
        num_iters: iteration count, forwarded unchanged as an FFI attribute.
            Callers pass a NumPy 32-bit scalar, presumably so the attribute
            width matches the C++ handler.
        eps: epsilon, forwarded unchanged as a NumPy float32 attribute.

    Returns:
        float32 array with the same shape as ``inp``.
    """
    x32 = inp.astype(jnp.float32)
    result_spec = jax.ShapeDtypeStruct(x32.shape, jnp.float32)
    fwd_call = jax.ffi.ffi_call("sinkhorn_fwd", result_spec, vmap_method="broadcast_all")
    return fwd_call(x32, num_iters=num_iters, eps=eps)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _sinkhorn_ffi_bwd(
    grad: jnp.ndarray,
    out_fwd: jnp.ndarray,
    inp: jnp.ndarray,
    num_iters: np.int32,
    eps: np.float32,
) -> jnp.ndarray:
    """Invoke the ``sinkhorn_bwd`` FFI target.

    Args:
        grad: cotangent of the forward output.
        out_fwd: forward output kept as a residual.
        inp: forward input kept as a residual.
        num_iters: iteration count attribute (NumPy 32-bit scalar).
        eps: epsilon attribute (NumPy float32).

    All three arrays are converted to float32 before the call.

    Returns:
        float32 gradient with respect to ``inp``, shaped like ``inp``.
    """
    grad32, out32, inp32 = (t.astype(jnp.float32) for t in (grad, out_fwd, inp))
    result_spec = jax.ShapeDtypeStruct(inp32.shape, jnp.float32)
    bwd_call = jax.ffi.ffi_call("sinkhorn_bwd", result_spec, vmap_method="broadcast_all")
    return bwd_call(grad32, out32, inp32, num_iters=num_iters, eps=eps)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
# 关键修复:在闭包创建时就将参数转换为numpy 32位类型
|
|
146
|
+
def _create_sinkhorn_kernel(num_iters: int, eps: float):
|
|
147
|
+
"""创建带有静态参数的sinkhorn kernel"""
|
|
148
|
+
|
|
149
|
+
# 在闭包外部转换为numpy 32位标量
|
|
150
|
+
num_iters_static = np.int32(num_iters) # <--- 确保32位
|
|
151
|
+
eps_static = np.float32(eps) # <--- 确保32位
|
|
152
|
+
|
|
153
|
+
@jax.custom_vjp
|
|
154
|
+
def _kernel(inp: jnp.ndarray) -> jnp.ndarray:
|
|
155
|
+
return _sinkhorn_ffi_fwd(inp, num_iters_static, eps_static)
|
|
156
|
+
|
|
157
|
+
def _fwd(inp: jnp.ndarray):
|
|
158
|
+
out = _sinkhorn_ffi_fwd(inp, num_iters_static, eps_static)
|
|
159
|
+
return out, (out, inp)
|
|
160
|
+
|
|
161
|
+
def _bwd(saved_vals: Tuple[jnp.ndarray, jnp.ndarray], grad: jnp.ndarray):
|
|
162
|
+
out_fwd, inp = saved_vals
|
|
163
|
+
d_inp = _sinkhorn_ffi_bwd(grad, out_fwd, inp, num_iters_static, eps_static)
|
|
164
|
+
return (d_inp,)
|
|
165
|
+
|
|
166
|
+
_kernel.defvjp(_fwd, _bwd)
|
|
167
|
+
return _kernel
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
# ---------- Public API ----------
def sinkhorn_knopp(
    inp: jnp.ndarray, num_iters: int = 20, eps: float = 1e-8
) -> jnp.ndarray:
    """Sinkhorn-Knopp (doubly-stochastic) projection using the JAX FFI kernel.

    Args:
        inp: [B, T, N, N] input matrices of any dtype. The computation runs
            in float32.
        num_iters: number of iterations. Must be a static (compile-time) value.
        eps: small constant guarding against division by zero. Must also be
            static.

    Returns:
        [B, T, N, N] result, cast back to the dtype of ``inp``.

    Raises:
        ValueError: if ``inp`` is not 4-D.
    """
    checked = _normalize_shape(inp, 4, "sinkhorn_knopp")
    dtype_in = checked.dtype
    as_f32 = jnp.asarray(checked, "float32")
    kernel = _create_sinkhorn_kernel(np.int32(num_iters), np.float32(eps))

    # Per JAX's checkpoint_policies, this policy saves every residual, so the
    # backward pass does not rematerialise the forward kernel.
    no_remat = cp.save_anything_except_these_names(())
    return jax.checkpoint(kernel, policy=no_remat)(as_f32).astype(dtype_in)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _rmsnorm_ffi_fwd(inp: jnp.ndarray, eps: np.float32) -> jnp.ndarray:
    """Invoke the ``rmsnorm_fwd`` FFI target on a bf16 input.

    Args:
        inp: input tensor. Cast to bfloat16 here, which is a no-op when it
            already is bf16. This matches the dtype ``_rmsnorm_ffi_bwd`` feeds
            the backward target.
        eps: epsilon, forwarded as an FFI attribute.

    Returns:
        bfloat16 array with the same shape as ``inp``.
    """
    # Cast explicitly so a direct caller cannot hand the handler another dtype.
    inp = inp.astype(jnp.bfloat16)
    out_type = jax.ShapeDtypeStruct(inp.shape, jnp.bfloat16)
    return jax.ffi.ffi_call("rmsnorm_fwd", out_type, vmap_method="broadcast_all")(
        inp, eps=eps
    )
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _rmsnorm_ffi_bwd(
    grad: jnp.ndarray, inp: jnp.ndarray, eps: np.float32
) -> jnp.ndarray:
    """Invoke the ``rmsnorm_bwd`` FFI target.

    Args:
        grad: cotangent of the forward output; converted to bf16.
        inp: forward input kept as a residual; converted to bf16.
        eps: epsilon, forwarded as an FFI attribute.

    Returns:
        bf16 gradient with respect to ``inp``, shaped like ``inp``.
    """
    g16 = grad.astype(jnp.bfloat16)
    x16 = inp.astype(jnp.bfloat16)
    bwd_call = jax.ffi.ffi_call(
        "rmsnorm_bwd",
        jax.ShapeDtypeStruct(x16.shape, jnp.bfloat16),
        vmap_method="broadcast_all",
    )
    return bwd_call(g16, x16, eps=eps)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _create_rmsnorm_kernel(eps: float):
|
|
234
|
+
"""创建带有静态eps的rmsnorm kernel"""
|
|
235
|
+
eps_static = np.float32(eps) # 编译期常量
|
|
236
|
+
|
|
237
|
+
@jax.custom_vjp
|
|
238
|
+
def _kernel(inp: jnp.ndarray) -> jnp.ndarray:
|
|
239
|
+
return _rmsnorm_ffi_fwd(inp, eps_static)
|
|
240
|
+
|
|
241
|
+
def _fwd(inp: jnp.ndarray):
|
|
242
|
+
out = _rmsnorm_ffi_fwd(inp, eps_static)
|
|
243
|
+
return out, (inp,) # 保存输入用于反向
|
|
244
|
+
|
|
245
|
+
def _bwd(saved_vals: Tuple[jnp.ndarray,], grad: jnp.ndarray):
|
|
246
|
+
(inp,) = saved_vals
|
|
247
|
+
dx = _rmsnorm_ffi_bwd(grad, inp, eps_static)
|
|
248
|
+
return (dx,)
|
|
249
|
+
|
|
250
|
+
_kernel.defvjp(_fwd, _bwd)
|
|
251
|
+
return _kernel
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
# ---------- Public API ----------
def rmsnorm(inp: jnp.ndarray, eps: float = 1e-5) -> jnp.ndarray:
    """RMSNorm over the last axis using the JAX FFI kernel.

    Args:
        inp: [..., C] input of any dtype, with at least 2 dimensions. The
            computation runs in bf16.
        eps: small constant guarding against division by zero.

    Returns:
        [..., C] normalised result with the shape and dtype of ``inp``.

    Raises:
        ValueError: if ``inp`` has fewer than 2 dimensions.
    """
    if inp.ndim < 2:
        raise ValueError(f"RMSNorm需要至少2维输入,但得到{inp.ndim}维")

    original_dtype = inp.dtype
    original_shape = inp.shape
    # Collapse all leading axes so the kernel sees a 2-D [rows, C] bf16 matrix.
    channels = inp.shape[-1]
    inp_2d = inp.astype(jnp.bfloat16).reshape(-1, channels)

    kernel = _create_rmsnorm_kernel(eps)
    # This policy saves every residual (no rematerialisation of the kernel).
    checkpointed_kernel = jax.checkpoint(
        kernel, policy=cp.save_anything_except_these_names(())
    )
    result_2d = checkpointed_kernel(inp_2d)

    return result_2d.astype(original_dtype).reshape(original_shape)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
# ---------- Stream Mix core implementation ----------
def _stream_mix_fwd(inp: jnp.ndarray, M: jnp.ndarray) -> jnp.ndarray:
    """Invoke the ``stream_mix_fwd`` FFI target.

    Args:
        inp: [B, T, n, C] streams; cast to bf16.
        M: [B, T, n, n] mixing weights; cast to float32.

    Returns:
        bf16 array with the same shape as ``inp``.
    """
    # Force the dtypes the backward wrapper also uses (bf16 input, fp32 M).
    # Both casts are no-ops when the caller already converted the inputs.
    inp = inp.astype(jnp.bfloat16)
    M = M.astype(jnp.float32)

    out_type = jax.ShapeDtypeStruct(inp.shape, jnp.bfloat16)
    return jax.ffi.ffi_call("stream_mix_fwd", out_type, vmap_method="broadcast_all")(
        inp, M
    )
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def _stream_mix_bwd(
    grad: jnp.ndarray, inp: jnp.ndarray, M: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Invoke the ``stream_mix_bwd`` FFI target.

    The handler takes ``grad`` as float32 (not bf16), ``inp`` as bf16 and ``M``
    as float32.

    Returns:
        ``(d_inp, d_M)``: bf16 gradient shaped like ``inp`` and float32
        gradient shaped like ``M``.
    """
    # The FFI signature expects a float32 cotangent.
    grad_f32 = grad.astype(jnp.float32)
    inp_bf16 = inp.astype(jnp.bfloat16)
    mix_f32 = M.astype(jnp.float32)

    result_specs = (
        jax.ShapeDtypeStruct(inp_bf16.shape, jnp.bfloat16),
        jax.ShapeDtypeStruct(mix_f32.shape, jnp.float32),
    )
    bwd_call = jax.ffi.ffi_call(
        "stream_mix_bwd", result_specs, vmap_method="broadcast_all"
    )
    d_inp, d_M = bwd_call(grad_f32, inp_bf16, mix_f32)
    return d_inp, d_M
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def _create_stream_mix_kernel():
|
|
324
|
+
"""创建Stream Mix kernel(无静态参数)"""
|
|
325
|
+
|
|
326
|
+
@jax.custom_vjp
|
|
327
|
+
def _kernel(inp: jnp.ndarray, M: jnp.ndarray) -> jnp.ndarray:
|
|
328
|
+
return _stream_mix_fwd(inp, M)
|
|
329
|
+
|
|
330
|
+
def _fwd(inp: jnp.ndarray, M: jnp.ndarray):
|
|
331
|
+
out = _stream_mix_fwd(inp, M)
|
|
332
|
+
# 保存输入用于反向
|
|
333
|
+
return out, (inp, M)
|
|
334
|
+
|
|
335
|
+
def _bwd(saved_vals: Tuple[jnp.ndarray, jnp.ndarray], grad: jnp.ndarray):
|
|
336
|
+
inp, M = saved_vals
|
|
337
|
+
d_inp, d_M = _stream_mix_bwd(grad, inp, M)
|
|
338
|
+
# 返回两个梯度,对应forward的两个输入
|
|
339
|
+
return d_inp, d_M
|
|
340
|
+
|
|
341
|
+
_kernel.defvjp(_fwd, _bwd)
|
|
342
|
+
return _kernel
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
# ---------- Public API ----------
def stream_mix(inp: jnp.ndarray, M: jnp.ndarray) -> jnp.ndarray:
    """Stream Mix using the JAX FFI kernel.

    Args:
        inp: [B, T, n, C] input of any dtype (converted to bf16 internally).
        M: [B, T, n, n] mixing matrices of any dtype (converted to fp32
            internally).

    Returns:
        [B, T, n, C] mixed result with the dtype of ``inp``.

    Raises:
        ValueError: if either argument is not 4-D, the B/T/n axes disagree,
            or ``M`` is not square in its last two axes.
    """
    if inp.ndim != 4:
        raise ValueError(f"Stream Mix需要4维输入,但得到{inp.ndim}维")
    if M.ndim != 4:
        raise ValueError(f"Stream Mix权重需要4维,但得到{M.ndim}维")
    if inp.shape[:3] != M.shape[:3]:
        raise ValueError(f"Batch/Time/Stream维度不匹配: inp{inp.shape}, M{M.shape}")
    # The kernel output has the same n streams as the input, so M must map
    # n streams to n streams. Reject non-square M before it reaches CUDA.
    if M.shape[3] != M.shape[2]:
        raise ValueError(
            f"Stream Mix权重M必须是[B, T, n, n],但得到M{M.shape}, inp{inp.shape}"
        )

    original_dtype = inp.dtype
    inp = inp.astype(jnp.bfloat16)
    M = M.astype(jnp.float32)

    kernel = _create_stream_mix_kernel()
    # This policy saves every residual (no rematerialisation of the kernel).
    checkpointed_kernel = jax.checkpoint(
        kernel, policy=cp.save_anything_except_these_names(())
    )
    result = checkpointed_kernel(inp, M)

    return result.astype(original_dtype)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
# ---------- Stream Aggregate core implementation ----------
def _stream_aggregate_ffi_fwd(
    inp: jnp.ndarray, H_pre: jnp.ndarray, per_token: bool
) -> jnp.ndarray:
    """Invoke the ``stream_aggregate_fwd`` FFI target (n streams -> 1).

    Args:
        inp: [B, T, n, C] input; cast to bf16.
        H_pre: aggregation weights, [B, T, n] when ``per_token`` is True or
            [n] when it is False; cast to float32.
        per_token: FFI attribute that selects per-token vs. shared
            per-stream weights.

    Returns:
        bf16 array of shape [B, T, C].
    """
    # Cast to the dtypes the handler is called with (bf16 input, fp32
    # weights). Both casts are no-ops when the caller already converted them.
    inp = inp.astype(jnp.bfloat16)
    H_pre = H_pre.astype(jnp.float32)

    B, T, _, C = inp.shape
    out_type = jax.ShapeDtypeStruct((B, T, C), jnp.bfloat16)

    return jax.ffi.ffi_call(
        "stream_aggregate_fwd", out_type, vmap_method="broadcast_all"
    )(inp, H_pre, per_token=per_token)
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
def _stream_aggregate_ffi_bwd(
    grad: jnp.ndarray, inp: jnp.ndarray, H_pre: jnp.ndarray, per_token: bool
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Invoke the ``stream_aggregate_bwd`` FFI target.

    The cotangent is passed as float32 so the kernel's reductions run at
    higher precision. ``inp`` goes in as bf16 and ``H_pre`` as float32.

    Returns:
        ``(d_inp, d_H_pre)``: bf16 gradient shaped like ``inp`` and float32
        gradient shaped exactly like ``H_pre`` ([B, T, n] or [n]).
    """
    operands = (
        grad.astype(jnp.float32),
        inp.astype(jnp.bfloat16),
        H_pre.astype(jnp.float32),
    )

    # Unpacking also enforces a 4-D input.
    batch, steps, streams, chans = inp.shape
    result_specs = (
        jax.ShapeDtypeStruct((batch, steps, streams, chans), jnp.bfloat16),
        jax.ShapeDtypeStruct(H_pre.shape, jnp.float32),
    )

    bwd_call = jax.ffi.ffi_call(
        "stream_aggregate_bwd", result_specs, vmap_method="broadcast_all"
    )
    d_inp, d_H = bwd_call(*operands, per_token=per_token)
    return d_inp, d_H
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def _create_stream_aggregate_kernel(per_token: bool):
|
|
422
|
+
"""创建Stream Aggregate kernel(支持两种权重模式)"""
|
|
423
|
+
|
|
424
|
+
@jax.custom_vjp
|
|
425
|
+
def _kernel(inp: jnp.ndarray, H_pre: jnp.ndarray) -> jnp.ndarray:
|
|
426
|
+
return _stream_aggregate_ffi_fwd(inp, H_pre, per_token)
|
|
427
|
+
|
|
428
|
+
def _fwd(inp: jnp.ndarray, H_pre: jnp.ndarray):
|
|
429
|
+
out = _stream_aggregate_ffi_fwd(inp, H_pre, per_token)
|
|
430
|
+
# 保存输入用于反向
|
|
431
|
+
return out, (inp, H_pre)
|
|
432
|
+
|
|
433
|
+
def _bwd(saved_vals: Tuple[jnp.ndarray, jnp.ndarray], grad: jnp.ndarray):
|
|
434
|
+
inp, H_pre = saved_vals
|
|
435
|
+
d_inp, d_H_pre = _stream_aggregate_ffi_bwd(grad, inp, H_pre, per_token)
|
|
436
|
+
# 返回两个梯度,对应forward的两个输入
|
|
437
|
+
return d_inp, d_H_pre
|
|
438
|
+
|
|
439
|
+
_kernel.defvjp(_fwd, _bwd)
|
|
440
|
+
return _kernel
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
# ---------- Public API ----------
def stream_aggregate(inp: jnp.ndarray, H_pre: jnp.ndarray) -> jnp.ndarray:
    """Stream Aggregate using the JAX FFI kernel.

    Computes ``Out = sum(inp * H_pre, axis=-2)``, a weighted sum over the
    stream axis. ``inp`` is sent to the kernel as bf16 and ``H_pre`` as
    float32. Per the original notes, products and accumulation are done in
    float32 inside the kernel. The result is cast back to ``inp.dtype``.

    Args:
        inp: [B, T, n, C] input of any dtype.
        H_pre: weights of any dtype, either
            - [B, T, n]: per-token weights (each token has its own), or
            - [n]: per-stream weights shared by all tokens.

    Returns:
        [B, T, C] aggregated result with the dtype of ``inp``.

    Raises:
        ValueError: if ``inp`` is not 4-D, ``H_pre`` is neither 1-D nor 3-D,
            or ``H_pre``'s shape does not match ``inp``.
    """
    if inp.ndim != 4:
        raise ValueError(f"Stream Aggregate需要4维输入,但得到{inp.ndim}维")
    if H_pre.ndim not in [1, 3]:
        raise ValueError(f"H_pre必须是1维或3维,但得到{H_pre.ndim}维")

    B, T, n, _ = inp.shape
    if H_pre.ndim == 1:
        if H_pre.shape[0] != n:
            # Report H_pre's own shape (not n twice) so the mismatch is visible.
            raise ValueError(f"全局权重H_pre的形状{H_pre.shape}与输入流数{n}不匹配")
        per_token = False
    else:  # H_pre.ndim == 3
        if H_pre.shape != (B, T, n):
            raise ValueError(
                f"Per-token权重H_pre形状{H_pre.shape}与输入形状{(B, T, n)}不匹配"
            )
        per_token = True

    original_dtype = inp.dtype
    inp = inp.astype(jnp.bfloat16)
    H_pre = H_pre.astype(jnp.float32)

    kernel = _create_stream_aggregate_kernel(per_token)
    # This policy saves every residual (no rematerialisation of the kernel).
    checkpointed_kernel = jax.checkpoint(
        kernel, policy=cp.save_anything_except_these_names(())
    )
    result = checkpointed_kernel(inp, H_pre)

    return result.astype(original_dtype)
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
# Register the Stream Distribute FFI handlers exported by the compiled library.
for _target_name, _symbol_name in (
    ("stream_distribute_fwd", "StreamDistributeFwd"),
    ("stream_distribute_bwd", "StreamDistributeBwd"),
):
    jax.ffi.register_ffi_target(
        _target_name,
        jax.ffi.pycapsule(getattr(_LIB, _symbol_name)),
        platform="CUDA",
    )
del _target_name, _symbol_name
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
# ---------- Stream Distribute core implementation ----------
def _stream_distribute_ffi_fwd(inp: jnp.ndarray, H_post: jnp.ndarray) -> jnp.ndarray:
    """Invoke the ``stream_distribute_fwd`` FFI target (1 stream -> n).

    Args:
        inp: [B, T, C] input; sent to the handler as bf16.
        H_post: weights whose last axis gives the stream count n; sent as
            float32.

    Returns:
        bf16 array of shape [B, T, n, C].
    """
    batch, steps, chans = inp.shape
    streams = H_post.shape[-1]
    result_spec = jax.ShapeDtypeStruct((batch, steps, streams, chans), jnp.bfloat16)

    distribute = jax.ffi.ffi_call(
        "stream_distribute_fwd", result_spec, vmap_method="broadcast_all"
    )
    return distribute(inp.astype(jnp.bfloat16), H_post.astype(jnp.float32))
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
def _stream_distribute_ffi_bwd(
    grad: jnp.ndarray, inp: jnp.ndarray, H_post: jnp.ndarray
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Invoke the ``stream_distribute_bwd`` FFI target.

    Unlike ``stream_mix_bwd``, this handler takes the cotangent as bf16. Per
    the original notes it is widened to float inside the kernel. ``inp`` is
    passed as bf16 and ``H_post`` as float32.

    Returns:
        ``(d_inp, d_H_post)``: bf16 gradient shaped like ``inp`` and float32
        gradient shaped like ``H_post``.
    """
    operands = (
        grad.astype(jnp.bfloat16),
        inp.astype(jnp.bfloat16),
        H_post.astype(jnp.float32),
    )
    result_specs = (
        jax.ShapeDtypeStruct(inp.shape, jnp.bfloat16),
        jax.ShapeDtypeStruct(H_post.shape, jnp.float32),
    )
    bwd_call = jax.ffi.ffi_call(
        "stream_distribute_bwd", result_specs, vmap_method="broadcast_all"
    )
    return bwd_call(*operands)
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
def _create_stream_distribute_kernel():
|
|
537
|
+
"""创建 Stream Distribute kernel"""
|
|
538
|
+
|
|
539
|
+
@jax.custom_vjp
|
|
540
|
+
def _kernel(inp: jnp.ndarray, H_post: jnp.ndarray) -> jnp.ndarray:
|
|
541
|
+
return _stream_distribute_ffi_fwd(inp, H_post)
|
|
542
|
+
|
|
543
|
+
def _fwd(inp: jnp.ndarray, H_post: jnp.ndarray):
|
|
544
|
+
out = _stream_distribute_ffi_fwd(inp, H_post)
|
|
545
|
+
return out, (inp, H_post)
|
|
546
|
+
|
|
547
|
+
def _bwd(saved_vals: Tuple[jnp.ndarray, jnp.ndarray], grad: jnp.ndarray):
|
|
548
|
+
inp, H_post = saved_vals
|
|
549
|
+
d_inp, d_H_post = _stream_distribute_ffi_bwd(grad, inp, H_post)
|
|
550
|
+
return d_inp, d_H_post
|
|
551
|
+
|
|
552
|
+
_kernel.defvjp(_fwd, _bwd)
|
|
553
|
+
return _kernel
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
# ---------- Public API ----------
def stream_distribute(inp: jnp.ndarray, H_post: jnp.ndarray) -> jnp.ndarray:
    """Stream Distribute (1 -> n) using the JAX FFI kernel.

    Computes ``Out = inp[:, :, None, :] * H_post[:, :, :, None]``.

    Args:
        inp: [B, T, C] input of any dtype (sent as bf16).
        H_post: [B, T, n] weights of any dtype (sent as float32).

    Returns:
        [B, T, n, C] multi-stream tensor with the dtype of ``inp``.

    Raises:
        ValueError: if either argument is not 3-D or their B/T axes differ.
    """
    if inp.ndim != 3 or H_post.ndim != 3:
        raise ValueError(
            f"stream_distribute 要求输入均为 3 维,得到 {inp.ndim} 和 {H_post.ndim}"
        )
    # The kernel's output shape is built from inp's B/T, so H_post must share
    # them. Reject mismatches before they reach CUDA.
    if inp.shape[:2] != H_post.shape[:2]:
        raise ValueError(
            f"stream_distribute 的 inp{inp.shape} 与 H_post{H_post.shape} 的 Batch/Time 维度不匹配"
        )

    original_dtype = inp.dtype
    inp = inp.astype(jnp.bfloat16)
    H_post = H_post.astype(jnp.float32)

    kernel = _create_stream_distribute_kernel()
    # This policy saves every residual (no rematerialisation of the kernel).
    checkpointed_kernel = jax.checkpoint(
        kernel, policy=cp.save_anything_except_these_names(())
    )
    result = checkpointed_kernel(inp, H_post)

    # Restore the caller's dtype.
    return result.astype(original_dtype)
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
# Register the mHC post-op FFI handlers exported by the compiled library.
for _target_name, _symbol_name in (
    ("mhc_post_op_fwd", "MhcPostOpFwd"),
    ("mhc_post_op_bwd", "MhcPostOpBwd"),
):
    jax.ffi.register_ffi_target(
        _target_name,
        jax.ffi.pycapsule(getattr(_LIB, _symbol_name)),
        platform="CUDA",
    )
del _target_name, _symbol_name
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
# ---------- mHC post-op: internal FFI wrappers ----------
def _mhc_post_op_ffi_fwd(layer_out, x_expanded, H_post, H_res):
    """Invoke the ``mhc_post_op_fwd`` FFI target.

    ``layer_out`` and ``x_expanded`` are sent as bf16, and ``H_post`` and
    ``H_res`` as float32. Returns a bf16 array shaped like ``x_expanded``.
    """
    result_spec = jax.ShapeDtypeStruct(x_expanded.shape, jnp.bfloat16)
    post_op = jax.ffi.ffi_call(
        "mhc_post_op_fwd", result_spec, vmap_method="broadcast_all"
    )
    operands = (
        layer_out.astype(jnp.bfloat16),
        x_expanded.astype(jnp.bfloat16),
        H_post.astype(jnp.float32),
        H_res.astype(jnp.float32),
    )
    return post_op(*operands)
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
def _mhc_post_op_ffi_bwd(grad, layer_out, x_expanded, H_post, H_res):
    """Invoke the ``mhc_post_op_bwd`` FFI target.

    Returns four gradients shaped like ``layer_out`` (bf16), ``x_expanded``
    (bf16), ``H_post`` (float32) and ``H_res`` (float32).
    """
    result_specs = tuple(
        jax.ShapeDtypeStruct(arr.shape, dt)
        for arr, dt in (
            (layer_out, jnp.bfloat16),
            (x_expanded, jnp.bfloat16),
            (H_post, jnp.float32),
            (H_res, jnp.float32),
        )
    )
    post_op_bwd = jax.ffi.ffi_call(
        "mhc_post_op_bwd", result_specs, vmap_method="broadcast_all"
    )
    bf16_args = [arr.astype(jnp.bfloat16) for arr in (grad, layer_out, x_expanded)]
    f32_args = [arr.astype(jnp.float32) for arr in (H_post, H_res)]
    return post_op_bwd(*bf16_args, *f32_args)
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
def _create_mhc_post_op_kernel():
|
|
630
|
+
@jax.custom_vjp
|
|
631
|
+
def _kernel(layer_out, x_expanded, H_post, H_res):
|
|
632
|
+
return _mhc_post_op_ffi_fwd(layer_out, x_expanded, H_post, H_res)
|
|
633
|
+
|
|
634
|
+
def _fwd(layer_out, x_expanded, H_post, H_res):
|
|
635
|
+
out = _mhc_post_op_ffi_fwd(layer_out, x_expanded, H_post, H_res)
|
|
636
|
+
return out, (layer_out, x_expanded, H_post, H_res)
|
|
637
|
+
|
|
638
|
+
def _bwd(saved_vals, grad):
|
|
639
|
+
lo, xe, hp, hr = saved_vals
|
|
640
|
+
return _mhc_post_op_ffi_bwd(grad, lo, xe, hp, hr)
|
|
641
|
+
|
|
642
|
+
_kernel.defvjp(_fwd, _bwd)
|
|
643
|
+
return _kernel
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
# ---------- Public API ----------
def mhc_post_op(
    layer_out: jnp.ndarray,
    x_expanded: jnp.ndarray,
    H_post: jnp.ndarray,
    H_res: jnp.ndarray,
) -> jnp.ndarray:
    """Fused mHC post-processing (residual mix + post-distribute).

    Computes ``x_next = (H_res @ x_expanded) + (layer_out * H_post)``.

    Args:
        layer_out: [B, T, C] output of the core layer.
        x_expanded: [B, T, n, C] previous expanded streams.
        H_post: [B, T, n] distribution weights.
        H_res: [B, T, n, n] mixing matrices.

    Returns:
        [B, T, n, C] updated streams with the dtype of ``x_expanded``.

    Raises:
        ValueError: if the shapes do not follow the layout above.
    """
    # Validate shapes up front, as the other public ops do, so a mismatch
    # fails here with a clear message instead of reaching the CUDA kernel.
    if x_expanded.ndim != 4:
        raise ValueError(f"mhc_post_op需要4维x_expanded,但得到{x_expanded.ndim}维")
    B, T, n, C = x_expanded.shape
    for arg_name, arr, expected in (
        ("layer_out", layer_out, (B, T, C)),
        ("H_post", H_post, (B, T, n)),
        ("H_res", H_res, (B, T, n, n)),
    ):
        if tuple(arr.shape) != expected:
            raise ValueError(
                f"mhc_post_op参数{arg_name}形状{tuple(arr.shape)}与期望{expected}不匹配"
            )

    original_dtype = x_expanded.dtype
    layer_out = layer_out.astype(jnp.bfloat16)
    x_expanded = x_expanded.astype(jnp.bfloat16)
    H_post = H_post.astype(jnp.float32)
    H_res = H_res.astype(jnp.float32)

    kernel = _create_mhc_post_op_kernel()
    # This policy saves every residual (no rematerialisation of the kernel).
    checkpointed_kernel = jax.checkpoint(
        kernel, policy=cp.save_anything_except_these_names(())
    )
    result = checkpointed_kernel(layer_out, x_expanded, H_post, H_res)
    return result.astype(original_dtype)
|
|
678
|
+
|
|
679
|
+
|
|
680
|
+
# ---------- Register the mHC pre-op FFI handlers ----------
for _target_name, _symbol_name in (
    ("mhc_pre_op_fwd", "MhcPreOpFwd"),
    ("mhc_pre_op_bwd", "MhcPreOpBwd"),
):
    jax.ffi.register_ffi_target(
        _target_name,
        jax.ffi.pycapsule(getattr(_LIB, _symbol_name)),
        platform="CUDA",
    )
del _target_name, _symbol_name
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
# ---------- MHC Pre-Op core implementation ----------
def _mhc_pre_op_ffi_fwd(
    x_expanded, h_pre_raw, h_post_raw, h_res_raw, sinkhorn_iters, eps
):
    """Invoke the ``mhc_pre_op_fwd`` FFI target.

    Expected shapes (per the original comments):
        x_expanded: [B, T, n, C]
        h_pre_raw, h_post_raw: [B, T, n]
        h_res_raw: [B, T, n, n], or already flattened to [B, T, n*n]

    ``sinkhorn_iters`` and ``eps`` are forwarded unchanged as FFI attributes.

    Returns:
        ``(x_layer_in, H_pre, H_post, H_res)`` with shapes [B, T, C] (bf16),
        [B, T, n] (f32), [B, T, n] (f32) and [B, T, n, n] (f32).
    """
    # The handler takes h_res flattened to [B, T, n*n].
    h_res_flat = (
        h_res_raw.reshape(h_res_raw.shape[0], h_res_raw.shape[1], -1)
        if h_res_raw.ndim == 4
        else h_res_raw
    )

    B, T, n, C = x_expanded.shape
    h_spec = jax.ShapeDtypeStruct((B, T, n), jnp.float32)
    result_specs = (
        jax.ShapeDtypeStruct((B, T, C), jnp.bfloat16),
        h_spec,
        h_spec,
        jax.ShapeDtypeStruct((B, T, n * n), jnp.float32),
    )

    pre_op = jax.ffi.ffi_call(
        "mhc_pre_op_fwd", result_specs, vmap_method="broadcast_all"
    )
    x_layer_in, H_pre, H_post, H_res_flat = pre_op(
        x_expanded.astype(jnp.bfloat16),
        h_pre_raw.astype(jnp.float32),
        h_post_raw.astype(jnp.float32),
        h_res_flat.astype(jnp.float32),
        sinkhorn_iters=sinkhorn_iters,
        eps=eps,
    )

    # The handler returns H_res flattened; restore the 4-D layout.
    return x_layer_in, H_pre, H_post, H_res_flat.reshape(B, T, n, n)
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
def _mhc_pre_op_ffi_bwd(
|
|
732
|
+
grad_layer_in,
|
|
733
|
+
grad_H_post,
|
|
734
|
+
grad_H_res,
|
|
735
|
+
x_expanded,
|
|
736
|
+
H_pre,
|
|
737
|
+
H_post,
|
|
738
|
+
H_res_out,
|
|
739
|
+
h_res_raw,
|
|
740
|
+
sinkhorn_iters,
|
|
741
|
+
eps,
|
|
742
|
+
):
|
|
743
|
+
"""内部FFI反向调用"""
|
|
744
|
+
# 展平梯度与残差以匹配 C++ 接口
|
|
745
|
+
grad_H_res_flat = grad_H_res.reshape(grad_H_res.shape[0], grad_H_res.shape[1], -1)
|
|
746
|
+
H_res_out_flat = H_res_out.reshape(H_res_out.shape[0], H_res_out.shape[1], -1)
|
|
747
|
+
|
|
748
|
+
# h_res_raw 是来自用户的原始输入,可能为 4D
|
|
749
|
+
if h_res_raw.ndim == 4:
|
|
750
|
+
h_res_raw_flat = h_res_raw.reshape(h_res_raw.shape[0], h_res_raw.shape[1], -1)
|
|
751
|
+
else:
|
|
752
|
+
h_res_raw_flat = h_res_raw
|
|
753
|
+
|
|
754
|
+
B, T, n, C = x_expanded.shape
|
|
755
|
+
d_x_shape = (B, T, n, C)
|
|
756
|
+
d_h_shape = (B, T, n)
|
|
757
|
+
d_h_res_shape = h_res_raw_flat.shape # 展平格式
|
|
758
|
+
|
|
759
|
+
d_x_type = jax.ShapeDtypeStruct(d_x_shape, jnp.bfloat16)
|
|
760
|
+
d_h_type = jax.ShapeDtypeStruct(d_h_shape, jnp.float32)
|
|
761
|
+
d_h_res_type = jax.ShapeDtypeStruct(d_h_res_shape, jnp.float32)
|
|
762
|
+
|
|
763
|
+
# 调用 FFI 反向 (返回 4 个梯度)
|
|
764
|
+
d_x_expanded, d_h_pre_raw, d_h_post_raw, d_h_res_raw_flat = jax.ffi.ffi_call(
|
|
765
|
+
"mhc_pre_op_bwd",
|
|
766
|
+
(d_x_type, d_h_type, d_h_type, d_h_res_type),
|
|
767
|
+
vmap_method="broadcast_all",
|
|
768
|
+
)(
|
|
769
|
+
grad_layer_in.astype(jnp.bfloat16),
|
|
770
|
+
grad_H_post.astype(jnp.float32),
|
|
771
|
+
grad_H_res_flat.astype(jnp.float32),
|
|
772
|
+
x_expanded.astype(jnp.bfloat16),
|
|
773
|
+
H_pre.astype(jnp.float32),
|
|
774
|
+
H_post.astype(jnp.float32),
|
|
775
|
+
H_res_out_flat.astype(jnp.float32),
|
|
776
|
+
h_res_raw_flat.astype(jnp.float32),
|
|
777
|
+
sinkhorn_iters=sinkhorn_iters,
|
|
778
|
+
eps=eps,
|
|
779
|
+
)
|
|
780
|
+
|
|
781
|
+
# 将 d_h_res_raw 重塑回 4D (若原始输入是 4D)
|
|
782
|
+
if h_res_raw.ndim == 4:
|
|
783
|
+
d_h_res_raw = d_h_res_raw_flat.reshape(B, T, n, n)
|
|
784
|
+
else:
|
|
785
|
+
d_h_res_raw = d_h_res_raw_flat
|
|
786
|
+
|
|
787
|
+
return d_x_expanded, d_h_pre_raw, d_h_post_raw, d_h_res_raw
|
|
788
|
+
|
|
789
|
+
|
|
790
|
+
def _create_mhc_pre_op_kernel(sinkhorn_iters: int, eps: float):
|
|
791
|
+
"""创建MHC Pre-Op kernel(静态参数固化)"""
|
|
792
|
+
# 在闭包外部固化 Sinkhorn 参数为 NumPy 32 位标量
|
|
793
|
+
sinkhorn_iters_static = np.int32(sinkhorn_iters)
|
|
794
|
+
eps_static = np.float32(eps)
|
|
795
|
+
|
|
796
|
+
@jax.custom_vjp
|
|
797
|
+
def _kernel(x_expanded, h_pre_raw, h_post_raw, h_res_raw):
|
|
798
|
+
# 调用 FFI 前向,返回 4 个张量
|
|
799
|
+
x_layer_in, H_pre, H_post, H_res = _mhc_pre_op_ffi_fwd(
|
|
800
|
+
x_expanded,
|
|
801
|
+
h_pre_raw,
|
|
802
|
+
h_post_raw,
|
|
803
|
+
h_res_raw,
|
|
804
|
+
sinkhorn_iters_static,
|
|
805
|
+
eps_static,
|
|
806
|
+
)
|
|
807
|
+
# 只返回 3 个张量给用户 (PyTorch 版本不返回 H_pre)
|
|
808
|
+
return x_layer_in, H_post, H_res
|
|
809
|
+
|
|
810
|
+
def _fwd(x_expanded, h_pre_raw, h_post_raw, h_res_raw):
|
|
811
|
+
# 调用前向并保存残差
|
|
812
|
+
x_layer_in, H_pre, H_post, H_res = _mhc_pre_op_ffi_fwd(
|
|
813
|
+
x_expanded,
|
|
814
|
+
h_pre_raw,
|
|
815
|
+
h_post_raw,
|
|
816
|
+
h_res_raw,
|
|
817
|
+
sinkhorn_iters_static,
|
|
818
|
+
eps_static,
|
|
819
|
+
)
|
|
820
|
+
# 残差包含反向所需的所有张量 (包括 H_pre, H_post, H_res)
|
|
821
|
+
return (x_layer_in, H_post, H_res), (
|
|
822
|
+
x_expanded,
|
|
823
|
+
H_pre,
|
|
824
|
+
H_post,
|
|
825
|
+
H_res,
|
|
826
|
+
h_res_raw,
|
|
827
|
+
)
|
|
828
|
+
|
|
829
|
+
def _bwd(residuals, grads):
|
|
830
|
+
# 解包残差
|
|
831
|
+
x_expanded, H_pre, H_post, H_res, h_res_raw = residuals
|
|
832
|
+
# 解包输出梯度 (3 个梯度,对应前向的 3 个输出)
|
|
833
|
+
# grads: (grad_x_layer_in, grad_H_post, grad_H_res)
|
|
834
|
+
grad_layer_in = grads[0]
|
|
835
|
+
grad_H_post = grads[1]
|
|
836
|
+
grad_H_res = grads[2]
|
|
837
|
+
|
|
838
|
+
# 调用 FFI 反向
|
|
839
|
+
return _mhc_pre_op_ffi_bwd(
|
|
840
|
+
grad_layer_in,
|
|
841
|
+
grad_H_post,
|
|
842
|
+
grad_H_res,
|
|
843
|
+
x_expanded,
|
|
844
|
+
H_pre,
|
|
845
|
+
H_post,
|
|
846
|
+
H_res,
|
|
847
|
+
h_res_raw,
|
|
848
|
+
sinkhorn_iters_static,
|
|
849
|
+
eps_static,
|
|
850
|
+
)
|
|
851
|
+
|
|
852
|
+
_kernel.defvjp(_fwd, _bwd)
|
|
853
|
+
return _kernel
|
|
854
|
+
|
|
855
|
+
|
|
856
|
+
# ---------- 公共 API ----------
|
|
857
|
+
def mhc_pre_op(
|
|
858
|
+
x_expanded: jnp.ndarray,
|
|
859
|
+
h_pre_raw: jnp.ndarray,
|
|
860
|
+
h_post_raw: jnp.ndarray,
|
|
861
|
+
h_res_raw: jnp.ndarray,
|
|
862
|
+
num_iters: int = 20,
|
|
863
|
+
eps: float = 1e-8,
|
|
864
|
+
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
|
865
|
+
"""
|
|
866
|
+
mHC 前处理融合算子 (Fused Aggregate + Sigmoid + Sinkhorn)
|
|
867
|
+
|
|
868
|
+
功能:
|
|
869
|
+
1. H_pre = sigmoid(h_pre_raw)
|
|
870
|
+
2. H_post = 2 * sigmoid(h_post_raw)
|
|
871
|
+
3. H_res = Sinkhorn(h_res_raw)
|
|
872
|
+
4. x_layer_in = sum(H_pre * x_expanded, axis=-2)
|
|
873
|
+
|
|
874
|
+
参数:
|
|
875
|
+
x_expanded: [B, T, n, C] 输入张量(任意dtype)
|
|
876
|
+
h_pre_raw: [B, T, n] 预权重原始值(任意dtype)
|
|
877
|
+
h_post_raw: [B, T, n] 后权重原始值(任意dtype)
|
|
878
|
+
h_res_raw: [B, T, n, n] 残差权重原始值(任意dtype)
|
|
879
|
+
num_iters: Sinkhorn 迭代次数(编译期常量)
|
|
880
|
+
eps: 防止除零的小常数(编译期常量)
|
|
881
|
+
|
|
882
|
+
返回:
|
|
883
|
+
tuple: (x_layer_in [B, T, C], H_post [B, T, n], H_res [B, T, n, n])
|
|
884
|
+
dtype 与输入一致
|
|
885
|
+
"""
|
|
886
|
+
# 形状检查
|
|
887
|
+
if x_expanded.ndim != 4:
|
|
888
|
+
raise ValueError(f"x_expanded 需要 4 维,但得到 {x_expanded.ndim}")
|
|
889
|
+
B, T, n, C = x_expanded.shape
|
|
890
|
+
|
|
891
|
+
expected_h_shape = (B, T, n)
|
|
892
|
+
if h_pre_raw.shape != expected_h_shape:
|
|
893
|
+
raise ValueError(
|
|
894
|
+
f"h_pre_raw 形状 {h_pre_raw.shape} 与期望 {expected_h_shape} 不匹配"
|
|
895
|
+
)
|
|
896
|
+
if h_post_raw.shape != expected_h_shape:
|
|
897
|
+
raise ValueError(
|
|
898
|
+
f"h_post_raw 形状 {h_post_raw.shape} 与期望 {expected_h_shape} 不匹配"
|
|
899
|
+
)
|
|
900
|
+
|
|
901
|
+
# h_res_raw 可以是 4D [B,T,n,n] 或 3D [B,T,n*n]
|
|
902
|
+
if h_res_raw.ndim == 4:
|
|
903
|
+
if h_res_raw.shape != (B, T, n, n):
|
|
904
|
+
raise ValueError(
|
|
905
|
+
f"h_res_raw 4D 形状 {h_res_raw.shape} 与期望 {(B, T, n, n)} 不匹配"
|
|
906
|
+
)
|
|
907
|
+
elif h_res_raw.ndim == 3:
|
|
908
|
+
if h_res_raw.shape != (B, T, n * n):
|
|
909
|
+
raise ValueError(
|
|
910
|
+
f"h_res_raw 3D 形状 {h_res_raw.shape} 与期望 {(B, T, n * n)} 不匹配"
|
|
911
|
+
)
|
|
912
|
+
else:
|
|
913
|
+
raise ValueError(f"h_res_raw 必须是 3 维或 4 维,但得到 {h_res_raw.ndim}")
|
|
914
|
+
|
|
915
|
+
original_dtype_x = x_expanded.dtype
|
|
916
|
+
original_dtype_h = h_pre_raw.dtype # 假设所有 h 张量 dtype 相同
|
|
917
|
+
|
|
918
|
+
# 类型转换:激活用 bf16,参数用 fp32
|
|
919
|
+
x_expanded = x_expanded.astype(jnp.bfloat16)
|
|
920
|
+
h_pre_raw = h_pre_raw.astype(jnp.float32)
|
|
921
|
+
h_post_raw = h_post_raw.astype(jnp.float32)
|
|
922
|
+
h_res_raw = h_res_raw.astype(jnp.float32)
|
|
923
|
+
|
|
924
|
+
# 创建并执行 kernel
|
|
925
|
+
kernel = _create_mhc_pre_op_kernel(num_iters, eps)
|
|
926
|
+
checkpointed_kernel = jax.checkpoint(
|
|
927
|
+
kernel, policy=cp.save_anything_except_these_names(())
|
|
928
|
+
)
|
|
929
|
+
|
|
930
|
+
x_layer_in, H_post, H_res = checkpointed_kernel(
|
|
931
|
+
x_expanded, h_pre_raw, h_post_raw, h_res_raw
|
|
932
|
+
)
|
|
933
|
+
|
|
934
|
+
# 类型还原
|
|
935
|
+
return (
|
|
936
|
+
x_layer_in.astype(original_dtype_x),
|
|
937
|
+
H_post.astype(original_dtype_h),
|
|
938
|
+
H_res.astype(original_dtype_h),
|
|
939
|
+
)
|