rwkv-ops 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py
@@ -0,0 +1,939 @@
+ """
+ JAX FFI version of the MHC operator library
+ - Sinkhorn-Knopp: doubly stochastic matrix projection
+ - The interface matches native_keras_op.py exactly
+ """
+
+ from __future__ import annotations
+ import pathlib
+ import subprocess
+ import ctypes
+ import numpy as np  # <--- numpy import added
+ from typing import Tuple
+ import jax
+ import jax.numpy as jnp
+ from jax.ad_checkpoint import checkpoint_policies as cp
+
+ # Current directory
+ _CURRENT_DIR = pathlib.Path(__file__).parent.absolute()
+
+
+ # ---------- Lazy compilation ----------
+ def _ensure_compiled() -> pathlib.Path:
+     """Compile the CUDA extension on first use."""
+     _SO_PATH = _CURRENT_DIR / "mhu.so"
+
+     if _SO_PATH.exists():
+         return _SO_PATH
+
+     print("[mhu_jax] First use - compiling the CUDA kernels...")
+
+     # Build directory
+     _BUILD_DIR = _CURRENT_DIR / "build"
+     build_dir = _BUILD_DIR
+     build_dir.mkdir(exist_ok=True)
+
+     # Locate the XLA header path
+     xla_include_dir = jax.ffi.include_dir()
+     if not xla_include_dir:
+         raise RuntimeError("jax.ffi.include_dir() returned nothing; please check that JAX >= 0.4.31")
+
+     # CMake configuration
+     cmake_args = [
+         "cmake",
+         "-S",
+         str(_CURRENT_DIR),
+         "-B",
+         str(build_dir),
+         "-DCMAKE_BUILD_TYPE=Release",
+         f"-DXLA_INCLUDE_DIR={xla_include_dir}",
+         "-DCMAKE_CUDA_FLAGS=-O3 --use_fast_math -std=c++17",
+     ]
+
+     try:
+         subprocess.check_call(cmake_args, cwd=build_dir)
+         subprocess.check_call(["cmake", "--build", str(build_dir), "-j"], cwd=build_dir)
+         subprocess.check_call(["cmake", "--install", str(build_dir)], cwd=build_dir)
+     except subprocess.CalledProcessError as e:
+         raise RuntimeError(f"CMake build failed: {e}")
+
+     if not _SO_PATH.exists():
+         files = list(_CURRENT_DIR.glob("*"))
+         raise RuntimeError(
+             f"Build failed - shared library not found at {_SO_PATH}\n"
+             f"Current directory contents: {[f.name for f in files]}"
+         )
+
+     print(f"[mhu_jax] Build finished - output: {_SO_PATH}")
+     return _SO_PATH
+
+
+ # ---------- FFI target registration ----------
+ _LIB = ctypes.CDLL(_ensure_compiled())
+ jax.ffi.register_ffi_target(
+     "sinkhorn_fwd", jax.ffi.pycapsule(_LIB.SinkhornFwd), platform="CUDA"
+ )
+ jax.ffi.register_ffi_target(
+     "sinkhorn_bwd", jax.ffi.pycapsule(_LIB.SinkhornBwd), platform="CUDA"
+ )
+ jax.ffi.register_ffi_target(
+     "rmsnorm_fwd", jax.ffi.pycapsule(_LIB.RMSNormFwd), platform="CUDA"
+ )
+ jax.ffi.register_ffi_target(
+     "rmsnorm_bwd", jax.ffi.pycapsule(_LIB.RMSNormBwd), platform="CUDA"
+ )
+ jax.ffi.register_ffi_target(
+     "stream_mix_fwd", jax.ffi.pycapsule(_LIB.StreamMixFwd), platform="CUDA"
+ )
+ jax.ffi.register_ffi_target(
+     "stream_mix_bwd", jax.ffi.pycapsule(_LIB.StreamMixBwd), platform="CUDA"
+ )
+ jax.ffi.register_ffi_target(
+     "stream_aggregate_fwd", jax.ffi.pycapsule(_LIB.StreamAggregateFwd), platform="CUDA"
+ )
+ jax.ffi.register_ffi_target(
+     "stream_aggregate_bwd", jax.ffi.pycapsule(_LIB.StreamAggregateBwd), platform="CUDA"
+ )
+
+
+ def _normalize_shape(x: jnp.ndarray, expected_ndim: int, name: str) -> jnp.ndarray:
+     """Ensure the array has the expected number of dimensions."""
+     if x.ndim != expected_ndim:
+         raise ValueError(f"{name} expects a {expected_ndim}-D tensor, but got {x.ndim}-D")
+     return x
+
+
+ # ---------- Core implementation ----------
+ def _sinkhorn_ffi_fwd(
+     inp: jnp.ndarray,
+     num_iters: np.int32,  # <--- use np.int32
+     eps: np.float32,  # <--- use np.float32
+ ) -> jnp.ndarray:
+     """Internal FFI forward call."""
+     inp = inp.astype(jnp.float32)
+     out_type = jax.ShapeDtypeStruct(inp.shape, jnp.float32)
+
+     # Pass through directly; these are already numpy 32-bit scalars
+     out = jax.ffi.ffi_call("sinkhorn_fwd", out_type, vmap_method="broadcast_all")(
+         inp, num_iters=num_iters, eps=eps
+     )
+
+     return out
+
+
+ def _sinkhorn_ffi_bwd(
+     grad: jnp.ndarray,
+     out_fwd: jnp.ndarray,
+     inp: jnp.ndarray,
+     num_iters: np.int32,  # <--- use np.int32
+     eps: np.float32,  # <--- use np.float32
+ ) -> jnp.ndarray:
+     """Internal FFI backward call."""
+     grad = grad.astype(jnp.float32)
+     out_fwd = out_fwd.astype(jnp.float32)
+     inp = inp.astype(jnp.float32)
+
+     d_inp_type = jax.ShapeDtypeStruct(inp.shape, jnp.float32)
+
+     d_inp = jax.ffi.ffi_call("sinkhorn_bwd", d_inp_type, vmap_method="broadcast_all")(
+         grad, out_fwd, inp, num_iters=num_iters, eps=eps
+     )
+
+     return d_inp
+
+
+ # Key fix: convert the parameters to numpy 32-bit types when the closure is created
+ def _create_sinkhorn_kernel(num_iters: int, eps: float):
+     """Create a sinkhorn kernel with static parameters."""
+
+     # Convert to numpy 32-bit scalars outside the closure
+     num_iters_static = np.int32(num_iters)  # <--- ensure 32-bit
+     eps_static = np.float32(eps)  # <--- ensure 32-bit
+
+     @jax.custom_vjp
+     def _kernel(inp: jnp.ndarray) -> jnp.ndarray:
+         return _sinkhorn_ffi_fwd(inp, num_iters_static, eps_static)
+
+     def _fwd(inp: jnp.ndarray):
+         out = _sinkhorn_ffi_fwd(inp, num_iters_static, eps_static)
+         return out, (out, inp)
+
+     def _bwd(saved_vals: Tuple[jnp.ndarray, jnp.ndarray], grad: jnp.ndarray):
+         out_fwd, inp = saved_vals
+         d_inp = _sinkhorn_ffi_bwd(grad, out_fwd, inp, num_iters_static, eps_static)
+         return (d_inp,)
+
+     _kernel.defvjp(_fwd, _bwd)
+     return _kernel
+
+
+ # ---------- Public API ----------
+ def sinkhorn_knopp(
+     inp: jnp.ndarray, num_iters: int = 20, eps: float = 1e-8
+ ) -> jnp.ndarray:
+     """
+     JAX FFI version of the Sinkhorn-Knopp operator
+
+     Args:
+         inp: [B, T, N, N] input matrices (any dtype)
+         num_iters: number of iterations (must be a compile-time constant)
+         eps: small constant guarding against division by zero (must be a compile-time constant)
+
+     Returns:
+         [B, T, N, N] doubly stochastic matrices, same dtype as the input
+     """
+     # dtype and shape checks
+     inp = _normalize_shape(inp, 4, "sinkhorn_knopp")
+     original_dtype = inp.dtype
+     inp = jnp.asarray(inp, "float32")
+     # Key fix: convert to numpy 32-bit types before creating the kernel
+     kernel = _create_sinkhorn_kernel(np.int32(num_iters), np.float32(eps))
+
+     # Use checkpoint to avoid recomputation
+     checkpointed_kernel = jax.checkpoint(
+         kernel, policy=cp.save_anything_except_these_names(())
+     )
+
+     # Run the computation
+     result = checkpointed_kernel(inp)
+
+     # Cast back to the original dtype
+     return result.astype(original_dtype)
+
+
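# --- Editor's illustrative usage sketch; not part of the packaged file. ---
# Shown to make the documented behaviour of sinkhorn_knopp concrete. It assumes a CUDA
# device, a working toolchain for the lazy CMake build above, and that the module is
# importable as rwkv_ops.mhc_kernel.jax_kernel.mhu_jax (the wheel layout suggests this,
# but the import path is an assumption). A strictly positive input is used, matching
# the classic Sinkhorn-Knopp setting; row and column sums of the output should then be
# close to 1 (doubly stochastic), which is only a loose sanity check.
import jax
import jax.numpy as jnp
from rwkv_ops.mhc_kernel.jax_kernel.mhu_jax import sinkhorn_knopp  # assumed path

key = jax.random.PRNGKey(0)
m = jax.random.uniform(key, (2, 4, 8, 8), minval=0.1, maxval=1.0)  # [B, T, N, N]
P = sinkhorn_knopp(m, num_iters=20, eps=1e-8)
print(P.shape)                          # (2, 4, 8, 8)
print(jnp.abs(P.sum(-1) - 1.0).max())   # row sums ~ 1
print(jnp.abs(P.sum(-2) - 1.0).max())   # column sums ~ 1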
+ def _rmsnorm_ffi_fwd(inp: jnp.ndarray, eps: np.float32) -> jnp.ndarray:
+     """Internal FFI forward call."""
+     # Ensure bf16 and contiguity
+
+     out_type = jax.ShapeDtypeStruct(inp.shape, jnp.bfloat16)
+
+     out = jax.ffi.ffi_call("rmsnorm_fwd", out_type, vmap_method="broadcast_all")(
+         inp, eps=eps
+     )
+
+     return out
+
+
+ def _rmsnorm_ffi_bwd(
+     grad: jnp.ndarray, inp: jnp.ndarray, eps: np.float32
+ ) -> jnp.ndarray:
+     """Internal FFI backward call."""
+     grad = grad.astype(jnp.bfloat16)
+     inp = inp.astype(jnp.bfloat16)
+
+     dx_type = jax.ShapeDtypeStruct(inp.shape, jnp.bfloat16)
+
+     dx = jax.ffi.ffi_call("rmsnorm_bwd", dx_type, vmap_method="broadcast_all")(
+         grad, inp, eps=eps
+     )
+
+     return dx
+
+
+ def _create_rmsnorm_kernel(eps: float):
+     """Create an rmsnorm kernel with a static eps."""
+     eps_static = np.float32(eps)  # compile-time constant
+
+     @jax.custom_vjp
+     def _kernel(inp: jnp.ndarray) -> jnp.ndarray:
+         return _rmsnorm_ffi_fwd(inp, eps_static)
+
+     def _fwd(inp: jnp.ndarray):
+         out = _rmsnorm_ffi_fwd(inp, eps_static)
+         return out, (inp,)  # save the input for the backward pass
+
+     def _bwd(saved_vals: Tuple[jnp.ndarray,], grad: jnp.ndarray):
+         (inp,) = saved_vals
+         dx = _rmsnorm_ffi_bwd(grad, inp, eps_static)
+         return (dx,)
+
+     _kernel.defvjp(_fwd, _bwd)
+     return _kernel
+
+
+ # ---------- Public API ----------
+ def rmsnorm(inp: jnp.ndarray, eps: float = 1e-5) -> jnp.ndarray:
+     """
+     JAX FFI version of the RMSNorm operator
+
+     Args:
+         inp: [..., C] input tensor (any dtype)
+         eps: small constant guarding against division by zero
+
+     Returns:
+         [..., C] normalized result, same dtype as the input
+     """
+     # Shape check (at least 2-D)
+     if inp.ndim < 2:
+         raise ValueError(f"RMSNorm requires at least a 2-D input, but got {inp.ndim}-D")
+
+     original_dtype = inp.dtype
+     original_shape = inp.shape
+     inp = inp.astype(jnp.bfloat16)
+     # Flatten to 2-D: [N, C]
+     N = inp.shape[0]
+     C = inp.shape[-1]
+     inp_2d = inp.reshape(-1, C)
+
+     # Create the kernel and run it
+     kernel = _create_rmsnorm_kernel(eps)
+     checkpointed_kernel = jax.checkpoint(
+         kernel, policy=cp.save_anything_except_these_names(())
+     )
+
+     result_2d = checkpointed_kernel(inp_2d)
+
+     # Restore the original shape
+     return result_2d.astype(original_dtype).reshape(original_shape)
+
+
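# --- Editor's illustrative usage sketch; not part of the packaged file. ---
# Demonstrates the documented contract of rmsnorm: normalization over the last axis,
# output shape and dtype matching the input, and differentiability through the custom
# VJP. Assumes a CUDA device and the same (assumed) import path as above; the kernel
# computes in bfloat16 internally, so numerical results are low precision.
import jax
import jax.numpy as jnp
from rwkv_ops.mhc_kernel.jax_kernel.mhu_jax import rmsnorm  # assumed path

x = jax.random.normal(jax.random.PRNGKey(1), (2, 16, 256), dtype=jnp.float32)
y = rmsnorm(x, eps=1e-5)
print(y.shape, y.dtype)  # (2, 16, 256) float32

# The custom VJP makes the op usable inside jax.grad.
dx = jax.grad(lambda t: rmsnorm(t).sum())(x)
print(dx.shape)          # (2, 16, 256)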
+ # ---------- Stream Mix core implementation ----------
+ def _stream_mix_fwd(inp: jnp.ndarray, M: jnp.ndarray) -> jnp.ndarray:
+     """Internal FFI forward call."""
+     # Forced dtype conversion
+
+     out_type = jax.ShapeDtypeStruct(inp.shape, jnp.bfloat16)
+
+     out = jax.ffi.ffi_call("stream_mix_fwd", out_type, vmap_method="broadcast_all")(
+         inp, M
+     )
+
+     return out
+
+
+ def _stream_mix_bwd(
+     grad: jnp.ndarray, inp: jnp.ndarray, M: jnp.ndarray
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+     """Internal FFI backward call."""
+     # Key fix: the gradient must be fp32, not bf16
+     grad = grad.astype(jnp.float32)  # changed from jnp.bfloat16 to jnp.float32
+     inp = inp.astype(jnp.bfloat16)
+     M = M.astype(jnp.float32)
+
+     d_inp_type = jax.ShapeDtypeStruct(inp.shape, jnp.bfloat16)
+     d_M_type = jax.ShapeDtypeStruct(M.shape, jnp.float32)
+
+     d_inp, d_M = jax.ffi.ffi_call(
+         "stream_mix_bwd", (d_inp_type, d_M_type), vmap_method="broadcast_all"
+     )(grad, inp, M)  # grad is now F32, matching the FFI signature
+
+     return d_inp, d_M
+
+
+ def _create_stream_mix_kernel():
+     """Create the Stream Mix kernel (no static parameters)."""
+
+     @jax.custom_vjp
+     def _kernel(inp: jnp.ndarray, M: jnp.ndarray) -> jnp.ndarray:
+         return _stream_mix_fwd(inp, M)
+
+     def _fwd(inp: jnp.ndarray, M: jnp.ndarray):
+         out = _stream_mix_fwd(inp, M)
+         # Save the inputs for the backward pass
+         return out, (inp, M)
+
+     def _bwd(saved_vals: Tuple[jnp.ndarray, jnp.ndarray], grad: jnp.ndarray):
+         inp, M = saved_vals
+         d_inp, d_M = _stream_mix_bwd(grad, inp, M)
+         # Return two gradients, one per forward input
+         return d_inp, d_M
+
+     _kernel.defvjp(_fwd, _bwd)
+     return _kernel
+
+
+ # ---------- Public API ----------
+ def stream_mix(inp: jnp.ndarray, M: jnp.ndarray) -> jnp.ndarray:
+     """
+     JAX FFI version of the Stream Mix operator
+
+     Args:
+         inp: [B, T, n, C] input tensor (any dtype; converted to bf16 internally)
+         M: [B, T, n, n] weight matrices (any dtype; converted to fp32 internally)
+
+     Returns:
+         [B, T, n, C] mixed result, same dtype as inp
+     """
+     # Shape checks
+     if inp.ndim != 4:
+         raise ValueError(f"Stream Mix requires a 4-D input, but got {inp.ndim}-D")
+     if M.ndim != 4:
+         raise ValueError(f"Stream Mix weights must be 4-D, but got {M.ndim}-D")
+     if inp.shape[:3] != M.shape[:3]:
+         raise ValueError(f"Batch/Time/Stream dimensions do not match: inp{inp.shape}, M{M.shape}")
+
+     original_dtype = inp.dtype
+     inp = inp.astype(jnp.bfloat16)
+     M = M.astype(jnp.float32)
+     # Create the kernel and run it
+     kernel = _create_stream_mix_kernel()
+     checkpointed_kernel = jax.checkpoint(
+         kernel, policy=cp.save_anything_except_these_names(())
+     )
+
+     result = checkpointed_kernel(inp, M)
+
+     return result.astype(original_dtype)
+
+
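# --- Editor's illustrative usage sketch; not part of the packaged file. ---
# Shows the shapes stream_mix accepts and that gradients flow to both inputs through
# the custom VJP. The exact mixing formula is not stated in this file, so no numerical
# reference is checked here. Same assumptions as above (CUDA device, assumed import path).
import jax
import jax.numpy as jnp
from rwkv_ops.mhc_kernel.jax_kernel.mhu_jax import stream_mix  # assumed path

B, T, n, C = 2, 8, 4, 64
k1, k2 = jax.random.split(jax.random.PRNGKey(2))
x = jax.random.normal(k1, (B, T, n, C), dtype=jnp.float32)
M = jax.random.normal(k2, (B, T, n, n), dtype=jnp.float32)

y = stream_mix(x, M)
print(y.shape, y.dtype)  # (2, 8, 4, 64) float32 (dtype follows inp)

dM = jax.grad(lambda m: stream_mix(x, m).sum())(M)
print(dM.shape)          # (2, 8, 4, 4)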
+ # ---------- Stream Aggregate core implementation ----------
+ def _stream_aggregate_ffi_fwd(
+     inp: jnp.ndarray, H_pre: jnp.ndarray, per_token: bool
+ ) -> jnp.ndarray:
+     """Internal FFI forward call."""
+     # Forced dtype conversion: BF16 inputs, FP32 weights
+
+     # Output shape: [B, T, C]
+     B, T, n, C = inp.shape
+     out_shape = (B, T, C)
+     out_type = jax.ShapeDtypeStruct(out_shape, jnp.bfloat16)
+
+     out = jax.ffi.ffi_call(
+         "stream_aggregate_fwd", out_type, vmap_method="broadcast_all"
+     )(inp, H_pre, per_token=per_token)
+
+     return out
+
+
+ def _stream_aggregate_ffi_bwd(
+     grad: jnp.ndarray, inp: jnp.ndarray, H_pre: jnp.ndarray, per_token: bool
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+     """Internal FFI backward call."""
+     # Key point: use float32 gradients for a high-precision reduction
+     grad_f32 = grad.astype(jnp.float32)
+     inp_bf16 = inp.astype(jnp.bfloat16)
+     H_f32 = H_pre.astype(jnp.float32)
+
+     B, T, n, C = inp.shape
+     # Output gradient shapes
+     d_inp_type = jax.ShapeDtypeStruct(inp.shape, jnp.bfloat16)
+     # d_H_pre must keep the same shape as H_pre
+     d_H_shape = H_pre.shape  # may be [B,T,n] or [n]
+     d_H_type = jax.ShapeDtypeStruct(d_H_shape, jnp.float32)
+
+     d_inp, d_H = jax.ffi.ffi_call(
+         "stream_aggregate_bwd", (d_inp_type, d_H_type), vmap_method="broadcast_all"
+     )(grad_f32, inp_bf16, H_f32, per_token=per_token)
+
+     return d_inp, d_H
+
+
+ def _create_stream_aggregate_kernel(per_token: bool):
+     """Create the Stream Aggregate kernel (supports both weight modes)."""
+
+     @jax.custom_vjp
+     def _kernel(inp: jnp.ndarray, H_pre: jnp.ndarray) -> jnp.ndarray:
+         return _stream_aggregate_ffi_fwd(inp, H_pre, per_token)
+
+     def _fwd(inp: jnp.ndarray, H_pre: jnp.ndarray):
+         out = _stream_aggregate_ffi_fwd(inp, H_pre, per_token)
+         # Save the inputs for the backward pass
+         return out, (inp, H_pre)
+
+     def _bwd(saved_vals: Tuple[jnp.ndarray, jnp.ndarray], grad: jnp.ndarray):
+         inp, H_pre = saved_vals
+         d_inp, d_H_pre = _stream_aggregate_ffi_bwd(grad, inp, H_pre, per_token)
+         # Return two gradients, one per forward input
+         return d_inp, d_H_pre
+
+     _kernel.defvjp(_fwd, _bwd)
+     return _kernel
+
+
+ # ---------- Public API ----------
+ def stream_aggregate(inp: jnp.ndarray, H_pre: jnp.ndarray) -> jnp.ndarray:
+     """
+     JAX FFI version of the Stream Aggregate operator
+
+     Computes: Out = sum(inp * H_pre, axis=-2)
+     Precision strategy: multiply and accumulate in float32, then cast back to the input dtype
+
+     Args:
+         inp: [B, T, n, C] input tensor (any dtype)
+         H_pre: [B, T, n] or [n] weight tensor (any dtype)
+             - [B, T, n]: per-token weights, an independent weight for every token
+             - [n]: per-stream weights, shared across all tokens
+
+     Returns:
+         [B, T, C] aggregated result, same dtype as inp
+     """
+     # Shape checks
+     if inp.ndim != 4:
+         raise ValueError(f"Stream Aggregate requires a 4-D input, but got {inp.ndim}-D")
+     if H_pre.ndim not in [1, 3]:
+         raise ValueError(f"H_pre must be 1-D or 3-D, but got {H_pre.ndim}-D")
+
+     B, T, n, C = inp.shape
+     if H_pre.ndim == 1:
+         if H_pre.shape[0] != n:
+             raise ValueError(
+                 f"Global weight H_pre has length {H_pre.shape[0]}, which does not match the number of streams {n}"
+             )
+         per_token = False
+     else:  # H_pre.ndim == 3
+         if H_pre.shape != (B, T, n):
+             raise ValueError(
+                 f"Per-token weight H_pre shape {H_pre.shape} does not match the expected shape {(B, T, n)}"
+             )
+         per_token = True
+
+     original_dtype = inp.dtype
+
+     inp = inp.astype(jnp.bfloat16)
+     H_pre = H_pre.astype(jnp.float32)
+     # Create the kernel and run it
+     kernel = _create_stream_aggregate_kernel(per_token)
+     checkpointed_kernel = jax.checkpoint(
+         kernel, policy=cp.save_anything_except_these_names(())
+     )
+
+     result = checkpointed_kernel(inp, H_pre)
+
+     return result.astype(original_dtype)
+
+
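# --- Editor's illustrative usage sketch; not part of the packaged file. ---
# The docstring defines stream_aggregate as Out = sum(inp * H_pre, axis=-2), so a plain
# jnp expression is used as a cross-check. Agreement is only approximate because the
# kernel holds activations in bfloat16. Same assumptions as above (CUDA device, assumed
# import path). The commented line shows the per-stream [n] weight variant.
import jax
import jax.numpy as jnp
from rwkv_ops.mhc_kernel.jax_kernel.mhu_jax import stream_aggregate  # assumed path

B, T, n, C = 2, 8, 4, 64
k1, k2 = jax.random.split(jax.random.PRNGKey(3))
x = jax.random.normal(k1, (B, T, n, C), dtype=jnp.float32)
h_pre = jax.random.uniform(k2, (B, T, n), dtype=jnp.float32)  # per-token weights
# h_pre = jax.random.uniform(k2, (n,))                        # per-stream weights

out = stream_aggregate(x, h_pre)
ref = (x * h_pre[..., None]).sum(axis=-2)
print(out.shape)                    # (2, 8, 64)
print(jnp.max(jnp.abs(out - ref)))  # small, bf16-level error expected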
+ # 1. Appended to the register_ffi_target section
+ jax.ffi.register_ffi_target(
+     "stream_distribute_fwd",
+     jax.ffi.pycapsule(_LIB.StreamDistributeFwd),
+     platform="CUDA",
+ )
+ jax.ffi.register_ffi_target(
+     "stream_distribute_bwd",
+     jax.ffi.pycapsule(_LIB.StreamDistributeBwd),
+     platform="CUDA",
+ )
+
+
+ # 2. Core logic
+ def _stream_distribute_ffi_fwd(inp: jnp.ndarray, H_post: jnp.ndarray) -> jnp.ndarray:
+     """Internal FFI forward call."""
+     B, T, C = inp.shape
+     n = H_post.shape[-1]
+     out_type = jax.ShapeDtypeStruct((B, T, n, C), jnp.bfloat16)
+
+     # Interface alignment: inp in bf16, H_post in f32
+     return jax.ffi.ffi_call(
+         "stream_distribute_fwd", out_type, vmap_method="broadcast_all"
+     )(inp.astype(jnp.bfloat16), H_post.astype(jnp.float32))
+
+
+ def _stream_distribute_ffi_bwd(
+     grad: jnp.ndarray, inp: jnp.ndarray, H_post: jnp.ndarray
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+     """Internal FFI backward call."""
+     # Force the gradient to bf16 to match the FFI signature; the kernel upcasts to float internally
+     grad_bf16 = grad.astype(jnp.bfloat16)
+     inp_bf16 = inp.astype(jnp.bfloat16)
+     H_f32 = H_post.astype(jnp.float32)
+
+     d_inp_type = jax.ShapeDtypeStruct(inp.shape, jnp.bfloat16)
+     d_H_type = jax.ShapeDtypeStruct(H_post.shape, jnp.float32)
+
+     return jax.ffi.ffi_call(
+         "stream_distribute_bwd", (d_inp_type, d_H_type), vmap_method="broadcast_all"
+     )(grad_bf16, inp_bf16, H_f32)
+
+
+ def _create_stream_distribute_kernel():
+     """Create the Stream Distribute kernel."""
+
+     @jax.custom_vjp
+     def _kernel(inp: jnp.ndarray, H_post: jnp.ndarray) -> jnp.ndarray:
+         return _stream_distribute_ffi_fwd(inp, H_post)
+
+     def _fwd(inp: jnp.ndarray, H_post: jnp.ndarray):
+         out = _stream_distribute_ffi_fwd(inp, H_post)
+         return out, (inp, H_post)
+
+     def _bwd(saved_vals: Tuple[jnp.ndarray, jnp.ndarray], grad: jnp.ndarray):
+         inp, H_post = saved_vals
+         d_inp, d_H_post = _stream_distribute_ffi_bwd(grad, inp, H_post)
+         return d_inp, d_H_post
+
+     _kernel.defvjp(_fwd, _bwd)
+     return _kernel
+
+
+ # 3. Public API
+ def stream_distribute(inp: jnp.ndarray, H_post: jnp.ndarray) -> jnp.ndarray:
+     """
+     JAX FFI version of the Stream Distribute operator (1 -> n)
+     Computes: Out = inp[:, :, None, :] * H_post[:, :, :, None]
+
+     Args:
+         inp: [B, T, C] input tensor
+         H_post: [B, T, n] weight tensor
+     Returns:
+         [B, T, n, C] distributed multi-stream tensor, same dtype as inp
+     """
+     # Shape checks
+     if inp.ndim != 3 or H_post.ndim != 3:
+         raise ValueError(
+             f"stream_distribute requires 3-D inputs, got {inp.ndim}-D and {H_post.ndim}-D"
+         )
+
+     original_dtype = inp.dtype
+     inp = inp.astype(jnp.bfloat16)
+     H_post = H_post.astype(jnp.float32)
+     # Create the kernel and run it
+     kernel = _create_stream_distribute_kernel()
+
+     # Uniform policy: never rematerialize (checkpoint policy)
+     checkpointed_kernel = jax.checkpoint(
+         kernel, policy=cp.save_anything_except_these_names(())
+     )
+
+     result = checkpointed_kernel(inp, H_post)
+
+     # Restore the dtype to avoid unnecessary dtype drift in gradient computations
+     return result.astype(original_dtype)
+
+
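# --- Editor's illustrative usage sketch; not part of the packaged file. ---
# The docstring gives the formula Out = inp[:, :, None, :] * H_post[:, :, :, None], so
# the broadcasted jnp product below serves as a loose, bf16-tolerance cross-check.
# Same assumptions as above (CUDA device, assumed import path).
import jax
import jax.numpy as jnp
from rwkv_ops.mhc_kernel.jax_kernel.mhu_jax import stream_distribute  # assumed path

B, T, n, C = 2, 8, 4, 64
k1, k2 = jax.random.split(jax.random.PRNGKey(4))
x = jax.random.normal(k1, (B, T, C), dtype=jnp.float32)
h_post = jax.random.uniform(k2, (B, T, n), dtype=jnp.float32)

out = stream_distribute(x, h_post)
ref = x[:, :, None, :] * h_post[:, :, :, None]
print(out.shape)                    # (2, 8, 4, 64)
print(jnp.max(jnp.abs(out - ref)))  # small, bf16-level error expected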
+ jax.ffi.register_ffi_target(
+     "mhc_post_op_fwd", jax.ffi.pycapsule(_LIB.MhcPostOpFwd), platform="CUDA"
+ )
+ jax.ffi.register_ffi_target(
+     "mhc_post_op_bwd", jax.ffi.pycapsule(_LIB.MhcPostOpBwd), platform="CUDA"
+ )
+
+
+ # 2. Internal FFI call wrappers
+ def _mhc_post_op_ffi_fwd(layer_out, x_expanded, H_post, H_res):
+     out_type = jax.ShapeDtypeStruct(x_expanded.shape, jnp.bfloat16)
+     return jax.ffi.ffi_call("mhc_post_op_fwd", out_type, vmap_method="broadcast_all")(
+         layer_out.astype(jnp.bfloat16),
+         x_expanded.astype(jnp.bfloat16),
+         H_post.astype(jnp.float32),
+         H_res.astype(jnp.float32),
+     )
+
+
+ def _mhc_post_op_ffi_bwd(grad, layer_out, x_expanded, H_post, H_res):
+     d_lo_type = jax.ShapeDtypeStruct(layer_out.shape, jnp.bfloat16)
+     d_xe_type = jax.ShapeDtypeStruct(x_expanded.shape, jnp.bfloat16)
+     d_hp_type = jax.ShapeDtypeStruct(H_post.shape, jnp.float32)
+     d_hr_type = jax.ShapeDtypeStruct(H_res.shape, jnp.float32)
+
+     return jax.ffi.ffi_call(
+         "mhc_post_op_bwd",
+         (d_lo_type, d_xe_type, d_hp_type, d_hr_type),
+         vmap_method="broadcast_all",
+     )(
+         grad.astype(jnp.bfloat16),
+         layer_out.astype(jnp.bfloat16),
+         x_expanded.astype(jnp.bfloat16),
+         H_post.astype(jnp.float32),
+         H_res.astype(jnp.float32),
+     )
+
+
+ def _create_mhc_post_op_kernel():
+     @jax.custom_vjp
+     def _kernel(layer_out, x_expanded, H_post, H_res):
+         return _mhc_post_op_ffi_fwd(layer_out, x_expanded, H_post, H_res)
+
+     def _fwd(layer_out, x_expanded, H_post, H_res):
+         out = _mhc_post_op_ffi_fwd(layer_out, x_expanded, H_post, H_res)
+         return out, (layer_out, x_expanded, H_post, H_res)
+
+     def _bwd(saved_vals, grad):
+         lo, xe, hp, hr = saved_vals
+         return _mhc_post_op_ffi_bwd(grad, lo, xe, hp, hr)
+
+     _kernel.defvjp(_fwd, _bwd)
+     return _kernel
+
+
+ # 3. Public API
+ def mhc_post_op(
+     layer_out: jnp.ndarray,
+     x_expanded: jnp.ndarray,
+     H_post: jnp.ndarray,
+     H_res: jnp.ndarray,
+ ) -> jnp.ndarray:
+     """
+     mHC post-processing fused operator (Fused Res-Mix + Post-Distribute)
+     Computes: x_next = (H_res @ x_expanded) + (layer_out * H_post)
+
+     Args:
+         layer_out: [B, T, C] output of the core layer
+         x_expanded: [B, T, n, C] previous expanded streams
+         H_post: [B, T, n] distribution weights
+         H_res: [B, T, n, n] mixing matrices
+     Returns:
+         [B, T, n, C] updated streams, same dtype as x_expanded
+     """
+     original_dtype = x_expanded.dtype
+     layer_out = layer_out.astype(jnp.bfloat16)
+     x_expanded = x_expanded.astype(jnp.bfloat16)
+     H_post = H_post.astype(jnp.float32)
+     H_res = H_res.astype(jnp.float32)
+     kernel = _create_mhc_post_op_kernel()
+     # Force checkpointing to save memory and avoid redundant recomputation
+     checkpointed_kernel = jax.checkpoint(
+         kernel, policy=cp.save_anything_except_these_names(())
+     )
+
+     result = checkpointed_kernel(layer_out, x_expanded, H_post, H_res)
+     return result.astype(original_dtype)
+
+
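# --- Editor's illustrative usage sketch; not part of the packaged file. ---
# The docstring states x_next = (H_res @ x_expanded) + (layer_out * H_post). Given the
# documented shapes, the einsum/broadcast expression below is one natural reading of
# that formula and is used only as a loose cross-check (bf16 internals, so expect
# noticeable error). Same assumptions as above (CUDA device, assumed import path).
import jax
import jax.numpy as jnp
from rwkv_ops.mhc_kernel.jax_kernel.mhu_jax import mhc_post_op  # assumed path

B, T, n, C = 2, 8, 4, 64
keys = jax.random.split(jax.random.PRNGKey(5), 4)
layer_out = jax.random.normal(keys[0], (B, T, C), dtype=jnp.float32)
x_exp = jax.random.normal(keys[1], (B, T, n, C), dtype=jnp.float32)
H_post = jax.random.uniform(keys[2], (B, T, n), dtype=jnp.float32)
H_res = jax.random.uniform(keys[3], (B, T, n, n), dtype=jnp.float32)

x_next = mhc_post_op(layer_out, x_exp, H_post, H_res)
ref = jnp.einsum("btij,btjc->btic", H_res, x_exp) + layer_out[:, :, None, :] * H_post[..., None]
print(x_next.shape)                    # (2, 8, 4, 64)
print(jnp.max(jnp.abs(x_next - ref)))  # loose agreement expected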
+ # ---------- Appended to the register_ffi_target section ----------
+ jax.ffi.register_ffi_target(
+     "mhc_pre_op_fwd", jax.ffi.pycapsule(_LIB.MhcPreOpFwd), platform="CUDA"
+ )
+ jax.ffi.register_ffi_target(
+     "mhc_pre_op_bwd", jax.ffi.pycapsule(_LIB.MhcPreOpBwd), platform="CUDA"
+ )
+
+
+ # ---------- MHC Pre-Op core implementation ----------
+ def _mhc_pre_op_ffi_fwd(
+     x_expanded, h_pre_raw, h_post_raw, h_res_raw, sinkhorn_iters, eps
+ ):
+     """Internal FFI forward call."""
+     # x_expanded: [B, T, n, C]
+     # h_pre_raw: [B, T, n]
+     # h_post_raw: [B, T, n]
+     # h_res_raw: [B, T, n, n] or [B, T, n*n]
+
+     # Flatten h_res_raw to match the C++ interface (which expects [B, T, n*n])
+     if h_res_raw.ndim == 4:
+         h_res_raw_flat = h_res_raw.reshape(h_res_raw.shape[0], h_res_raw.shape[1], -1)
+     else:
+         h_res_raw_flat = h_res_raw
+
+     B, T, n, C = x_expanded.shape
+
+     # Define the output shapes (H_res is returned flattened by the C++ side)
+     out_type_x_layer_in = jax.ShapeDtypeStruct((B, T, C), jnp.bfloat16)
+     out_type_H = jax.ShapeDtypeStruct((B, T, n), jnp.float32)
+     out_type_H_res_flat = jax.ShapeDtypeStruct((B, T, n * n), jnp.float32)
+
+     # Call the FFI forward (returns 4 tensors)
+     x_layer_in, H_pre, H_post, H_res_flat = jax.ffi.ffi_call(
+         "mhc_pre_op_fwd",
+         (out_type_x_layer_in, out_type_H, out_type_H, out_type_H_res_flat),
+         vmap_method="broadcast_all",
+     )(
+         x_expanded.astype(jnp.bfloat16),
+         h_pre_raw.astype(jnp.float32),
+         h_post_raw.astype(jnp.float32),
+         h_res_raw_flat.astype(jnp.float32),
+         sinkhorn_iters=sinkhorn_iters,
+         eps=eps,
+     )
+
+     # Reshape H_res back to 4-D
+     H_res = H_res_flat.reshape(B, T, n, n)
+     return x_layer_in, H_pre, H_post, H_res
+
+
+ def _mhc_pre_op_ffi_bwd(
+     grad_layer_in,
+     grad_H_post,
+     grad_H_res,
+     x_expanded,
+     H_pre,
+     H_post,
+     H_res_out,
+     h_res_raw,
+     sinkhorn_iters,
+     eps,
+ ):
+     """Internal FFI backward call."""
+     # Flatten the gradients and residuals to match the C++ interface
+     grad_H_res_flat = grad_H_res.reshape(grad_H_res.shape[0], grad_H_res.shape[1], -1)
+     H_res_out_flat = H_res_out.reshape(H_res_out.shape[0], H_res_out.shape[1], -1)
+
+     # h_res_raw is the user's original input and may be 4-D
+     if h_res_raw.ndim == 4:
+         h_res_raw_flat = h_res_raw.reshape(h_res_raw.shape[0], h_res_raw.shape[1], -1)
+     else:
+         h_res_raw_flat = h_res_raw
+
+     B, T, n, C = x_expanded.shape
+     d_x_shape = (B, T, n, C)
+     d_h_shape = (B, T, n)
+     d_h_res_shape = h_res_raw_flat.shape  # flattened layout
+
+     d_x_type = jax.ShapeDtypeStruct(d_x_shape, jnp.bfloat16)
+     d_h_type = jax.ShapeDtypeStruct(d_h_shape, jnp.float32)
+     d_h_res_type = jax.ShapeDtypeStruct(d_h_res_shape, jnp.float32)
+
+     # Call the FFI backward (returns 4 gradients)
+     d_x_expanded, d_h_pre_raw, d_h_post_raw, d_h_res_raw_flat = jax.ffi.ffi_call(
+         "mhc_pre_op_bwd",
+         (d_x_type, d_h_type, d_h_type, d_h_res_type),
+         vmap_method="broadcast_all",
+     )(
+         grad_layer_in.astype(jnp.bfloat16),
+         grad_H_post.astype(jnp.float32),
+         grad_H_res_flat.astype(jnp.float32),
+         x_expanded.astype(jnp.bfloat16),
+         H_pre.astype(jnp.float32),
+         H_post.astype(jnp.float32),
+         H_res_out_flat.astype(jnp.float32),
+         h_res_raw_flat.astype(jnp.float32),
+         sinkhorn_iters=sinkhorn_iters,
+         eps=eps,
+     )
+
+     # Reshape d_h_res_raw back to 4-D (if the original input was 4-D)
+     if h_res_raw.ndim == 4:
+         d_h_res_raw = d_h_res_raw_flat.reshape(B, T, n, n)
+     else:
+         d_h_res_raw = d_h_res_raw_flat
+
+     return d_x_expanded, d_h_pre_raw, d_h_post_raw, d_h_res_raw
+
+
+ def _create_mhc_pre_op_kernel(sinkhorn_iters: int, eps: float):
+     """Create the MHC Pre-Op kernel (with the static parameters frozen in)."""
+     # Freeze the Sinkhorn parameters as NumPy 32-bit scalars outside the closure
+     sinkhorn_iters_static = np.int32(sinkhorn_iters)
+     eps_static = np.float32(eps)
+
+     @jax.custom_vjp
+     def _kernel(x_expanded, h_pre_raw, h_post_raw, h_res_raw):
+         # Call the FFI forward, which returns 4 tensors
+         x_layer_in, H_pre, H_post, H_res = _mhc_pre_op_ffi_fwd(
+             x_expanded,
+             h_pre_raw,
+             h_post_raw,
+             h_res_raw,
+             sinkhorn_iters_static,
+             eps_static,
+         )
+         # Only return 3 tensors to the user (the PyTorch version does not return H_pre)
+         return x_layer_in, H_post, H_res
+
+     def _fwd(x_expanded, h_pre_raw, h_post_raw, h_res_raw):
+         # Run the forward pass and save the residuals
+         x_layer_in, H_pre, H_post, H_res = _mhc_pre_op_ffi_fwd(
+             x_expanded,
+             h_pre_raw,
+             h_post_raw,
+             h_res_raw,
+             sinkhorn_iters_static,
+             eps_static,
+         )
+         # The residuals contain every tensor the backward pass needs (including H_pre, H_post, H_res)
+         return (x_layer_in, H_post, H_res), (
+             x_expanded,
+             H_pre,
+             H_post,
+             H_res,
+             h_res_raw,
+         )
+
+     def _bwd(residuals, grads):
+         # Unpack the residuals
+         x_expanded, H_pre, H_post, H_res, h_res_raw = residuals
+         # Unpack the output gradients (3 gradients, one per forward output)
+         # grads: (grad_x_layer_in, grad_H_post, grad_H_res)
+         grad_layer_in = grads[0]
+         grad_H_post = grads[1]
+         grad_H_res = grads[2]
+
+         # Call the FFI backward
+         return _mhc_pre_op_ffi_bwd(
+             grad_layer_in,
+             grad_H_post,
+             grad_H_res,
+             x_expanded,
+             H_pre,
+             H_post,
+             H_res,
+             h_res_raw,
+             sinkhorn_iters_static,
+             eps_static,
+         )
+
+     _kernel.defvjp(_fwd, _bwd)
+     return _kernel
+
+
+ # ---------- Public API ----------
+ def mhc_pre_op(
+     x_expanded: jnp.ndarray,
+     h_pre_raw: jnp.ndarray,
+     h_post_raw: jnp.ndarray,
+     h_res_raw: jnp.ndarray,
+     num_iters: int = 20,
+     eps: float = 1e-8,
+ ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+     """
+     mHC pre-processing fused operator (Fused Aggregate + Sigmoid + Sinkhorn)
+
+     Computes:
+         1. H_pre = sigmoid(h_pre_raw)
+         2. H_post = 2 * sigmoid(h_post_raw)
+         3. H_res = Sinkhorn(h_res_raw)
+         4. x_layer_in = sum(H_pre * x_expanded, axis=-2)
+
+     Args:
+         x_expanded: [B, T, n, C] input tensor (any dtype)
+         h_pre_raw: [B, T, n] raw pre-weights (any dtype)
+         h_post_raw: [B, T, n] raw post-weights (any dtype)
+         h_res_raw: [B, T, n, n] raw residual-mixing weights (any dtype)
+         num_iters: number of Sinkhorn iterations (compile-time constant)
+         eps: small constant guarding against division by zero (compile-time constant)
+
+     Returns:
+         tuple: (x_layer_in [B, T, C], H_post [B, T, n], H_res [B, T, n, n])
+             dtypes match the inputs
+     """
+     # Shape checks
+     if x_expanded.ndim != 4:
+         raise ValueError(f"x_expanded must be 4-D, but got {x_expanded.ndim}-D")
+     B, T, n, C = x_expanded.shape
+
+     expected_h_shape = (B, T, n)
+     if h_pre_raw.shape != expected_h_shape:
+         raise ValueError(
+             f"h_pre_raw shape {h_pre_raw.shape} does not match the expected {expected_h_shape}"
+         )
+     if h_post_raw.shape != expected_h_shape:
+         raise ValueError(
+             f"h_post_raw shape {h_post_raw.shape} does not match the expected {expected_h_shape}"
+         )
+
+     # h_res_raw may be 4-D [B,T,n,n] or 3-D [B,T,n*n]
+     if h_res_raw.ndim == 4:
+         if h_res_raw.shape != (B, T, n, n):
+             raise ValueError(
+                 f"4-D h_res_raw shape {h_res_raw.shape} does not match the expected {(B, T, n, n)}"
+             )
+     elif h_res_raw.ndim == 3:
+         if h_res_raw.shape != (B, T, n * n):
+             raise ValueError(
+                 f"3-D h_res_raw shape {h_res_raw.shape} does not match the expected {(B, T, n * n)}"
+             )
+     else:
+         raise ValueError(f"h_res_raw must be 3-D or 4-D, but got {h_res_raw.ndim}-D")
+
+     original_dtype_x = x_expanded.dtype
+     original_dtype_h = h_pre_raw.dtype  # assumes all h tensors share a dtype
+
+     # Dtype conversion: bf16 for activations, fp32 for parameters
+     x_expanded = x_expanded.astype(jnp.bfloat16)
+     h_pre_raw = h_pre_raw.astype(jnp.float32)
+     h_post_raw = h_post_raw.astype(jnp.float32)
+     h_res_raw = h_res_raw.astype(jnp.float32)
+
+     # Create the kernel and run it
+     kernel = _create_mhc_pre_op_kernel(num_iters, eps)
+     checkpointed_kernel = jax.checkpoint(
+         kernel, policy=cp.save_anything_except_these_names(())
+     )
+
+     x_layer_in, H_post, H_res = checkpointed_kernel(
+         x_expanded, h_pre_raw, h_post_raw, h_res_raw
+     )
+
+     # Restore the dtypes
+     return (
+         x_layer_in.astype(original_dtype_x),
+         H_post.astype(original_dtype_h),
+         H_res.astype(original_dtype_h),
+     )
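
# --- Editor's illustrative usage sketch; not part of the packaged file. ---
# Exercises the fused pre-op end to end and loosely checks two of the documented steps:
# H_post = 2 * sigmoid(h_post_raw) and H_res = Sinkhorn(h_res_raw) (whose rows should
# sum to roughly 1 if the iteration has converged). Same assumptions as above
# (CUDA device, assumed import path); defaults num_iters=20, eps=1e-8 are used.
import jax
import jax.numpy as jnp
from rwkv_ops.mhc_kernel.jax_kernel.mhu_jax import mhc_pre_op  # assumed path

B, T, n, C = 2, 8, 4, 64
keys = jax.random.split(jax.random.PRNGKey(6), 4)
x_exp = jax.random.normal(keys[0], (B, T, n, C), dtype=jnp.float32)
h_pre_raw = jax.random.normal(keys[1], (B, T, n), dtype=jnp.float32)
h_post_raw = jax.random.normal(keys[2], (B, T, n), dtype=jnp.float32)
h_res_raw = jax.random.normal(keys[3], (B, T, n, n), dtype=jnp.float32)

x_layer_in, H_post, H_res = mhc_pre_op(x_exp, h_pre_raw, h_post_raw, h_res_raw)
print(x_layer_in.shape, H_post.shape, H_res.shape)  # (2, 8, 64) (2, 8, 4) (2, 8, 4, 4)
print(jnp.max(jnp.abs(H_post - 2.0 * jax.nn.sigmoid(h_post_raw))))  # docstring step 2
print(jnp.max(jnp.abs(H_res.sum(-1) - 1.0)))  # Sinkhorn rows roughly sum to 1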