rwkv-ops 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89) hide show
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
@@ -0,0 +1,495 @@
1
+ Metadata-Version: 2.4
2
+ Name: rwkv-ops
3
+ Version: 0.6.1
4
+ Summary: RWKV operators for multiple backends (PyTorch, JAX, Keras)
5
+ Project-URL: Homepage, https://github.com/pass-lin/rwkv_ops
6
+ Author-email: pass-lin <qw_lin@qq.com>
7
+ License: Apache-2.0
8
+ License-File: LICENSE.txt
9
+ Classifier: License :: OSI Approved :: Apache Software License
10
+ Classifier: Operating System :: OS Independent
11
+ Classifier: Programming Language :: Python :: 3
12
+ Requires-Python: >=3.8
13
+ Requires-Dist: keras>=3.0
14
+ Description-Content-Type: text/markdown
15
+
16
+ [English Document](ENREADME.md)
17
+
18
+ # RWKV OPS 项目
19
+
20
+ > 由于 RWKV 将持续迭代,核心算子会随之更新。
21
+ > 本仓专门维护「算子」本身,不维护 layer 与 model;尽可能提供各框架的 GPU 算子。
22
+
23
+ ### 当前支持
24
+ | 算子类型 | 框架支持 |
25
+ |----------|----------|
26
+ | GPU 算子 | PyTorch、JAX|
27
+ | 原生算子 | PyTorch、JAX、TensorFlow、NumPy |
28
+
29
+ > 未来若 Keras 生态扩展,可能支持 MLX、OpenVINO。
30
+ > 注意:本库依赖 `keras`。
31
+
32
+ ---
33
+
34
+ ## 安装
35
+
36
+ ```bash
37
+ pip install rwkv_ops
38
+ ```
39
+
40
+ 当然pip包对于编译的算子pip uninstal没法删干净,所有可以试着从源码安装
41
+ ```bash
42
+ git clone https://github.com/pass-lin/rwkv_ops.git
43
+ cd rwkv_ops
44
+ bash install.sh
45
+ ```
46
+ ---
47
+
48
+ ## 环境变量
49
+
50
+ | 变量名 | 含义 | 取值 | 默认值 | 优先级 |
51
+ |---|---|---|---|---|
52
+ | `KERAS_BACKEND` | Keras 后端 | `jax` / `torch` / `tensorflow` / `numpy` | — | 低 |
53
+ | `KERNEL_BACKEND` | 算子后端 | `jax` / `torch` / `tensorflow` / `numpy` | `torch` | **高** |
54
+ | `KERNEL_TYPE` | 实现类型 | `triton` / `cuda` / `native` | `cuda` | — |
55
+
56
+ > 若 `KERNEL_BACKEND` 有值,直接采用;若为空,则用 `KERAS_BACKEND`;两者皆空则默认 `torch`。
57
+
58
+ ---
59
+
60
+
61
+ ## [MHC 算子](https://arxiv.org/abs/2512.24880)
62
+
63
+ 虽然和RWKV无关,但是我懒得再开一个包做分发了,就集成在这了吧。
64
+
65
+ ### 背景
66
+
67
+ 在多头架构中,传统做法通常是简单的线性变换或加权。MHC 引入了 **Sinkhorn-Knopp** 算法,将控制权重约束在双稳态矩阵空间内,从而保证了信息流动的守恒性和稳定性。由于这些操作涉及大量的中间变量和迭代计算,原生实现的显存占用极高。本库提供的 CUDA 算子通过 **算子融合(Operator Fusion)** 技术,显著降低了显存消耗并提升了运行速度。
68
+
69
+ 需要注意的是,这个仓库只提供最朴素的cuda实现。使用了https://github.com/AndreSlavescu/mHC.cu 和 Gemini协助完成。现在的代码可以成功通过测试
70
+
71
+ ---
72
+
73
+ ### MHC 算子列表
74
+
75
+ | 算子名称 | 核心功能 |
76
+ | --- | --- |
77
+ | `mhc_pre_op` | **前置融合算子**:处理输入流聚合与 Sinkhorn 矩阵准备。 |
78
+ | `mhc_post_op` | **后置融合算子**:处理层输出分发与多头残差融合。 |
79
+ | `sinkhorn_knopp` | 矩阵双稳态归一化(支持高精度反向迭代)。 |
80
+ | `rmsnorm` | 针对 MHC 输入分布优化的 RMS 归一化。 |
81
+ | `stream_aggregate` | 特征流加权聚合(多流转单流)。 |
82
+ | `stream_distribute` | 特征流加权分发(单流转多流)。 |
83
+ | `stream_mix` | 特征流间的动态混合(Cross-head Mixing)。 |
84
+
85
+ ---
86
+
87
+ ### MHC 典型集成流程
88
+
89
+ MHC 算子的使用流程,这是一个取代resnet的框架。我懒得写容器了,大概的使用流程如下所示:
90
+
91
+ ```python
92
+ from rwkv_ops import rmsnorm, mhc_pre_op, mhc_post_op
93
+
94
+ # 1. 归一化输入
95
+ x_norm = rmsnorm(x_expanded)
96
+
97
+ # 2. 生成原始控制参数 (通常通过 Linear 层)
98
+ # h_res_raw: [B, T, N, N], h_pre_raw/h_post_raw: [B, T, N]
99
+ h_res_raw, h_pre_raw, h_post_raw = linear_and_reshape(x_norm)
100
+
101
+ # 3. MHC 前置处理 (Fused Kernel)
102
+ x_layer_in, H_post, H_res = mhc_pre_op(
103
+ x_expanded, h_pre_raw, h_post_raw, h_res_raw, num_iters=20
104
+ )
105
+
106
+ # 4. 执行核心层逻辑 (x_layer_in 为聚合后的单流 [B, T, C])
107
+ layer_out = YourCoreLayer(x_layer_in)
108
+
109
+ # 5. MHC 后置处理 (Fused Kernel)
110
+ x_next = mhc_post_op(layer_out, x_expanded, H_post, H_res)
111
+
112
+ ```
113
+
114
+ ---
115
+
116
+ ### 算子详细定义
117
+
118
+ #### 1. `mhc_pre_op` (Fused Pre-computation)
119
+
120
+ **融合说明**:该算子等价于以下原生操作的融合:
121
+
122
+ * 对 `h_pre_raw` 执行 `Sigmoid` 得到前置门控。
123
+ * 对 `h_res_raw` 执行 `Exp` + `Sinkhorn-Knopp` 得到双稳态权重矩阵。
124
+ * 对 `x_expanded` 执行 `stream_aggregate`(按头聚合)。
125
+ * **融合优势**:避免了存储巨大的 Exp 矩阵和中间迭代状态,显存占用降低约 80%。
126
+
127
+ **接口定义**:
128
+
129
+ * **输入**:
130
+ * `x_expanded` : 展开后的 个特征头。
131
+ * `h_pre_raw` : 原始前置系数。
132
+ * `h_post_raw` : 原始后置系数。
133
+ * `h_res_raw` : 原始 Sinkhorn 输入。
134
+
135
+
136
+ * **返回**:
137
+ * `x_layer_in` : 融合后的层输入。
138
+ * `H_pre`, `H_post`, `H_res`: 供 `post_op` 及反向传播使用的归一化系数。
139
+
140
+
141
+
142
+ #### 2. `mhc_post_op` (Fused Post-computation)
143
+
144
+ **融合说明**:该算子等价于以下原生操作的融合:
145
+
146
+ * 执行 `stream_distribute` 将单流输出映射回多流。
147
+ * 执行 `stream_mix` 利用 `H_res` 进行头间信息交换。
148
+ * 对 `H_post` 执行 并作为残差门控。
149
+ * 执行多头残差加法:。
150
+
151
+ **接口定义**:
152
+
153
+ * **输入**: `layer_out` , `x_expanded` , `H_post`, `H_res`。
154
+ * **返回**: `x_next` 。
155
+
156
+
157
+ #### 3. `sinkhorn_knopp`
158
+
159
+ * **定义**: 。
160
+ * **特点**: CUDA 实现采用双向迭代,其梯度计算通过求解伴随状态方程实现,比直接对迭代过程进行自动微分更稳定且更省显存。
161
+
162
+ #### 4. `rmsnorm`
163
+
164
+ * **定义**: 标准的归一化与流操作,但在 CUDA 实现中针对 MHC 的特征分布(通常在 维度上)进行了特定的访存优化。参考mHC的实现,这个rmsnorm是不带参数的。
165
+
166
+
167
+ ### 5. `stream_aggregate`
168
+
169
+ **功能定义**:
170
+ 该算子执行加权空间压缩。它将 个独立的特征头(Streams)根据给定的权重向量进行线性加权求和,坍缩为一个统一的特征表示。
171
+
172
+ * **数学表达式**:,其中 ,。
173
+ * **等价操作**:相当于执行了 `torch.einsum('btn,btnc->btc', weights, x)`。
174
+
175
+ **接口定义**:
176
+
177
+ * **输入**:
178
+ * `x`: 多流特征张量。
179
+ * `weights`: 每个流对应的权重系数。
180
+
181
+
182
+ * **返回**:
183
+ * `out`: 聚合后的单流特征。
184
+
185
+
186
+
187
+ ---
188
+
189
+ ### 6. `stream_distribute`
190
+
191
+ **功能定义**:
192
+ 该算子执行特征的空间广播与重加权。它将一个单流特征复制到 个通道,并分别乘以对应的分发权重。
193
+
194
+ * **数学表达式**:,其中 。
195
+ * **等价操作**:相当于执行了 `x.unsqueeze(2) * weights.unsqueeze(-1)`。
196
+
197
+ **接口定义**:
198
+
199
+ * **输入**:
200
+ * `x`: 待分发的单流特征(通常来自核心层输出)。
201
+ * `weights`: 分发到每个头的权重系数。
202
+
203
+
204
+ * **返回**:
205
+ * `out`: 分发后的多流特征张量。
206
+
207
+
208
+
209
+ ---
210
+
211
+ ### 7. `stream_mix` (流混合)
212
+
213
+ **功能定义**:
214
+ 该算子是 MHC 实现头间通信(Cross-head Communication)的关键。它利用一个 的变换矩阵(通常是 Sinkhorn 归一化后的矩阵),在不同的特征头之间进行线性重组。
215
+
216
+ * **数学表达式**:。
217
+ * **等价操作**:相当于执行了 `torch.einsum('btnm,btmc->btnc', gate, x)`。
218
+ * **物理意义**:实现了信息的“非局部”重分配,使得每个头都能吸收来自其他头的信息。
219
+
220
+ **接口定义**:
221
+
222
+ * **输入**:
223
+ * `x`: 原始多流特征。
224
+ * `gate`: 混合矩阵(控制信息流动的开关与强度)。
225
+
226
+
227
+ * **返回**:
228
+ * `out`: 混合后的多流特征。
229
+
230
+
231
+
232
+ ---
233
+
234
+ ### MHC 实现状态
235
+
236
+ | Framework | cuda | triton | native |
237
+ | --- | --- | --- | --- |
238
+ | **PyTorch** | ✅ | ❌ | ✅ |
239
+ | **JAX** | ✅ | ❌ | ✅ |
240
+ | **TensorFlow** | ❌ | ❌ | ✅ |
241
+ | **NumPy** | ❌ | ❌ | ✅ |
242
+
243
+
244
+ ---
245
+
246
+
247
+
248
+ ```python
249
+ from rwkv_ops import generalized_delta_rule,generalized_delta_rule_inference # 或 from rwkv_ops import rwkv7_op,完全等价
250
+ #generalized_delta_rule_inference的入口和这个接口一致
251
+ #但是generalized_delta_rule_inference是没有梯度只支持inference的
252
+ def generalized_delta_rule(
253
+ r,
254
+ w,
255
+ k,
256
+ v,
257
+ a,
258
+ b,
259
+ initial_state=None,
260
+ output_final_state: bool = True,
261
+ head_first: bool = False,
262
+ ):
263
+ """
264
+ 分块 Delta Rule 注意力接口。
265
+
266
+ Args:
267
+ q: [B, T, H, K]
268
+ k: [B, T, H, K]
269
+ v: [B, T, H, V]
270
+ a: [B, T, H, K]
271
+ b: [B, T, H, K]
272
+ gk: [B, T, H, K] # decay term in log space!
273
+ initial_state: 初始状态 [N, H, K, V],N 为序列数
274
+ output_final_state: 是否返回最终状态
275
+ head_first: 是否 head-first 格式,不支持变长
276
+
277
+ Returns:
278
+ o: 输出 [B, T, H, V] 或 [B, H, T, V]
279
+ final_state: 最终状态 [N, H, K, V] 或 None
280
+ """
281
+ ```
282
+ generalized_delta_rule_inference和generalized_delta_rule的区别是前者没有梯度。因为不需要存储激活值,所以可以节省一部分显存。
283
+
284
+ ### cuda-kernel 特殊用法
285
+
286
+ - torch-cuda和jax-cuda kernel 下 `head_size` 也是一个 kernel 参数,默认为 64。
287
+ - 若 `head_size ≠ 64`,请使用:
288
+
289
+ ```python
290
+ from rwkv_ops import get_generalized_delta_rule
291
+
292
+ rwkv7_op, rwkv7_op_inference, USE_TRITON_KERNEL = get_generalized_delta_rule(
293
+ your_head_size, KERNEL_TYPE="cuda"
294
+ )
295
+ ```
296
+
297
+ - `USE_TRITON_KERNEL` 为常量,标记是否使用 chunkwise 算子。
298
+ - 两者 padding 处理逻辑不同:
299
+
300
+ ```python
301
+ if padding_mask is not None:
302
+ w += (1 - padding_mask) * -1e9
303
+ ```
304
+ - 对于上面的代码,基于循环的算子可以针对left pading和right pading都能成功处理。
305
+ - 而如果用的是chunkwise算子,建议统一left padding,如果是cuda或者原生,则都left right都能正确处理
306
+
307
+
308
+ ### rwkv7op 实现状态
309
+
310
+ | Framework | cuda | triton | native |
311
+ |-------------|------|--------|--------|
312
+ | PyTorch | ✅ | ✅ | ✅ |
313
+ | JAX | ✅ | ✅ | ✅ |
314
+ | TensorFlow | ⚠️ | ❌ | ✅ |
315
+ | NumPy | ❌ | ❌ | ✅ |
316
+ | MLX | ⚠️ | ❌ | ❌ |
317
+
318
+ ---
319
+ 1. `native` 为原生算子,无 chunkwise,速度慢且显存高。
320
+ 2. `triton` 使用的是chunkwise算法实现,速度快,并行度高,缺点是精度很差,介意勿用
321
+ 3. `cuda` 为基于 CUDA 的原生算子,速度很快,并且kernel内部使用fp32实现,所以精度也很高。缺点就是长序列的时候比较吃亏跑不满。
322
+ 4. tensorflow的CUDA实现只支持前向计算,是没有梯度的。并且这个是使用jax的cuda实现实现的,你需要保证你能够成功运行jax的cuda kernel。
323
+ 5. tensorflow kernel只支持eager
324
+ 6. 因为MLX还没合并到keras,所以原生算子暂不支持。但是我们提供了一个前向的算子。
325
+ ## rwkv7_op_rnn 使用方法
326
+
327
+ ### 背景
328
+ 这是RWKV7 OP的特殊情况,就是我们只考虑长度=1的情况。专门用于推理的decode阶段的加速
329
+
330
+ ### 使用方法
331
+
332
+ ```python
333
+ from rwkv_ops import rwkv7_op_rnn
334
+ def rwkv7_op_rnn(
335
+ r: jnp.ndarray,
336
+ w: jnp.ndarray,
337
+ k: jnp.ndarray,
338
+ v: jnp.ndarray,
339
+ a: jnp.ndarray,
340
+ b: jnp.ndarray,
341
+ initial_state: Optional[jnp.ndarray] = None,
342
+ output_final_state: bool = True,
343
+ head_first: bool = False,
344
+ )
345
+ """
346
+ 单步广义 delta 规则(仅前向)
347
+ 参数:
348
+ r,w,k,v,a,b: 输入张量,形状必须为 (B, 1, H, K) 或 (B, H, 1, K)
349
+ initial_state: 可选 (B, H, K, K) 初始状态,None 则零初始化
350
+ output_final_state: 是否同时返回最后状态
351
+ head_first: 是否将 head 维提前
352
+ 返回:
353
+ out: (B, 1, H, K) 与输入 dtype 一致
354
+ last_state: (B, H, K, K) 当 output_final_state=True
355
+ """
356
+ ```
357
+ ### rwkv7_op_rnn 实现状态
358
+
359
+ | Framework | cuda | triton | native |
360
+ |-------------|------|--------|--------|
361
+ | PyTorch | ✅ | ❌ | ✅ |
362
+ | JAX | ✅ | ❌ | ✅ |
363
+ | TensorFlow | ⚠️ | ❌ | ✅ |
364
+ | NumPy | ❌ | ❌ | ✅ |
365
+
366
+ 1. tf的cuda实现依赖于jax的cuda实现,所以需要安装jax
367
+ 2. native实现我们直接复用了rwkv7_op的native实现
368
+ 3. **这个算子没有梯度**
369
+ 4. tensorflow kernel只支持eager
370
+
371
+ ## rwkv6op 使用方法
372
+
373
+ ### PyTorch 使用注意事项
374
+
375
+ - 安装依赖:`keras`、`ninja`、完整的 CUDA 工具包。
376
+ - 若使用 VS Code + 虚拟环境调试,请务必在终端手动激活虚拟环境,再运行代码,否则 ninja 可能无法工作。
377
+ - 虽然 PyTorch 在「虚拟环境中的 CUDA 版本」与「全局 CUDA 版本」不一致时仍可正常运行,但强烈建议保持一致。
378
+ - PyTorch 限制:同一程序内只能实例化 **一个** `RWKV6_OP` 对象;算子线程安全(无状态),可在多处调用。
379
+
380
+ ### JAX 使用注意事项
381
+
382
+ - 安装依赖:`keras`、`gcc`、`pybind11`、完整的 CUDA 工具包。
383
+ - 即使通过虚拟环境为 JAX 安装 CUDA,也必须在系统级安装完整 CUDA;两者版本需一致,以保证 JAX 并行编译速度。
384
+ - JAX 编译依赖 `/usr/local/cuda` 软链接,如不存在请手动创建:
385
+ ```shell
386
+ sudo ln -sf /usr/local/cuda-12.4 /usr/local/cuda
387
+ ```
388
+ - 确保 `nvcc -V` 正常输出,且 `which nvcc` 指向正确版本。
389
+ - JAX 限制:同一程序内只能实例化 **一个** `RWKV6_OP` 对象;算子线程安全(无状态),可在多处调用。
390
+ - JAX ≥ 0.6.0 不再使用 CUDA 算子,默认使用原生算子;推荐 0.4.34。
391
+
392
+ ### TensorFlow 使用注意事项
393
+
394
+ - 仅提供基于原生 API 的 `RWKV6` 算子,仅用于推理,效率较低。
395
+
396
+ ---
397
+
398
+ ### 使用方法
399
+ 需要注意的是,和rwkv7写成函数的形式不一样,RWKV6的op是一个类,需要实例化。
400
+ ```python
401
+ from rwkv_ops import RWKV6_OP
402
+
403
+ operator = RWKV6_OP(
404
+ head_size=64, # 头大小,不确定时填 64
405
+ max_sequence_length=4096, # 训练最大序列长度;推理不受限
406
+ ops_loop=False # 可选:序列长度=1 时是否用上层 API 替代 CUDA
407
+ )
408
+ ```
409
+
410
+ #### 调用
411
+
412
+ ```python
413
+ y, y_state = operator(
414
+ r, k, v, w, u,
415
+ with_state=False, # 是否使用自定义初始状态 / 输出结束状态
416
+ init_state=None, # 初始状态 [n_state, num_heads, head_size, head_size]
417
+ state_map=None # int32 一维数组,长度=batch_size,定义 init_state 映射
418
+ )
419
+ ```
420
+
421
+ | 参数 | 形状 | 说明 |
422
+ |---|---|---|
423
+ | r, k, v, w | (batch_size, seq_len, hidden_size) | — |
424
+ | u | (num_heads, head_size) 或 (hidden_size,) | — |
425
+ | init_state | (n_state, num_heads, head_size, head_size) | n_state=1 时所有样本共用;n_state=batch_size 时一一对应 |
426
+ | state_map | (batch_size,) | 指定每个样本用到的 init_state 索引 |
427
+
428
+ | 返回值 | 形状 | 说明 |
429
+ |---|---|---|
430
+ | y | (batch_size, seq_len, hidden_size) | 输出 |
431
+ | y_state | (batch_size, num_heads, head_size, head_size) 或 None | 结束状态 |
432
+
433
+ ---
434
+
435
+ ### 分布式小贴士
436
+
437
+ - 算子本身无分布式支持;PyTorch 可直接用多线程分布式。
438
+ - JAX 需通过 `shard_map` 包装(示例):
439
+
440
+ ```python
441
+ import os
442
+ os.environ['KERAS_BACKEND'] = 'jax'
443
+
444
+ import jax, jax.numpy as jnp
445
+ from jax.experimental.shard_map import shard_map
446
+ from jax.sharding import Mesh, PartitionSpec as P
447
+ from functools import partial
448
+ from rwkv_ops import RWKV6_OP
449
+
450
+ batch_size, seq_length = 24, 512
451
+ head_size, num_heads = 64, 32
452
+ hidden_size = head_size * num_heads
453
+
454
+ mesh = Mesh(jax.devices('gpu'), axis_names=('device_axis',))
455
+ device_ns = NamedSharding(mesh, P('device_axis'))
456
+
457
+ operator = RWKV6_OP(head_size=head_size, max_sequence_length=seq_length)
458
+
459
+ @partial(shard_map,
460
+ mesh=mesh,
461
+ in_specs=(P('device_axis'),) * 5,
462
+ out_specs=(P('device_axis'), P('device_axis')),
463
+ check_rep=False)
464
+ def call_kernel(r, k, v, w, u):
465
+ # 去掉最外 device 维度
466
+ r, k, v, w, u = map(jnp.squeeze, (r, k, v, w, u))
467
+ y, ys = operator(r, k, v, w, u, with_state=True)
468
+ return jnp.expand_dims(y, 0), jnp.expand_dims(ys, 0)
469
+
470
+ # 构造输入并放置到对应设备
471
+ keys = jax.random.split(jax.random.PRNGKey(0), 5)
472
+ inputs = [jax.random.normal(k, (mesh.size, batch_size, seq_length, hidden_size)) for k in keys]
473
+ inputs_r, inputs_k, inputs_v, inputs_w, inputs_u = map(
474
+ lambda x: jax.device_put(x, device_ns), inputs)
475
+ inputs_u = inputs_u[:, :, 0] # (devices, hidden_size)
476
+
477
+ # 可选:jax.jit(call_kernel, ...) 加速
478
+ outputs_y, y_state = call_kernel(inputs_r, inputs_k, inputs_v, inputs_w, inputs_u)
479
+
480
+ print(outputs_y.shape, outputs_y.sharding)
481
+ print(y_state.shape, y_state.sharding)
482
+ ```
483
+
484
+ ---
485
+
486
+ ### rwkv6op 实现状态
487
+
488
+ | Framework | cuda | triton | native |
489
+ |-------------|------|--------|--------|
490
+ | PyTorch | ✅ | ❌ | ✅ |
491
+ | JAX | ⚠️ | ❌ | ✅ |
492
+ | TensorFlow | ❌ | ❌ | ✅ |
493
+ | NumPy | ❌ | ❌ | ✅ |
494
+
495
+ ⚠️ JAX 的 CUDA 实现仅适用于 < 0.6.0,推荐 0.4.34。
@@ -0,0 +1,89 @@
1
+ rwkv_ops/__init__.py,sha256=ElOPSg7VdT0MfR8XHlJfkJ5REjzJrkl3h8z8r7GFuqQ,1328
2
+ rwkv_ops/mhc_kernel/__init__.py,sha256=ntgmb-My0GOyYVm8rRd3nsQ4zcOS3hwUzBd9-osqq1I,1302
3
+ rwkv_ops/mhc_kernel/native_keras_op.py,sha256=m-Zv2BwaLSGqmV9LwSOZspT-P65g9WHtdwIeePyvb38,6456
4
+ rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h,sha256=FrgrYsFFZiuLKjUDYcDufVWbjkzEpMXs4uofKuuPYkg,1589
5
+ rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh,sha256=h80ANUMCTS5ay3TfAqcMZEsriHEB1UTOHVem614-E8g,7225
6
+ rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh,sha256=nLt77BaVdUfGSzsbc9AfA4R9U4Xx2LdL2CJh4xsLGCQ,7105
7
+ rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh,sha256=d5ZSGypygRlUdRUZpw2ja1Jm7Webw2UNJv__AhKbXmQ,4891
8
+ rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh,sha256=EePKkq99cHIrOTJPf-2owpckfE5NBDEF6XAAC4Pgfu0,4755
9
+ rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh,sha256=ths8hF6IPxlSfIRuioZThmblttka9-Je1aE-O_1MzpA,5222
10
+ rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh,sha256=3kJX_qlx8Im5awvx7Lk36FZUhDJY45UYkyjPkx8M4Ak,3195
11
+ rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh,sha256=2nl9us8jqm2mT4PfOfM03z2R5VaBkuh8WHSQKcU3R38,5327
12
+ rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh,sha256=TpDbVuZ-auBWy9JMvotoxJlxaffsDxil47FTftoZs58,1911
13
+ rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt,sha256=AF-YrNsDDi0Jgw7BOjff0vsTeM91X4lxlI0vuE0on0I,1440
14
+ rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu,sha256=fhmA0SA0iqPCqtAzwYKHs1WkiDX89lIhkWzvILqcSUU,23816
15
+ rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py,sha256=xSNUmHdfbFrGbI9qMuDnwI96eJDJ4V6DrVim5W0ixtU,29950
16
+ rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu,sha256=_JKKXXWg33jK6xfDiMax6tYmWGiVI8BM0onHoA-eN4U,8058
17
+ rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp,sha256=-cYLybjaKplAuvOA09DhzoGSFqtnElswpajmbEnuInw,15602
18
+ rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py,sha256=T9IGyqSSbgZyLGniLm5ACzQFotkgeaOuMO2KcWKUGGQ,10344
19
+ rwkv_ops/rwkv6_kernel/__init__.py,sha256=ktIzkK6EUc2nonLQnl2NAjJj9kMt02i9zqfjFcnM_NQ,3647
20
+ rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py,sha256=4SL93Z4mmuQldHtmwqTKcP7M-outTU5Rge2qgDGzwBg,29966
21
+ rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py,sha256=c3ZSJ9xC6-PKr88pOhjmBximdhwmP1_i7UOcIdKB43c,3354
22
+ rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py,sha256=Pv0WsBp5byTSwkYrYkHcJa3wftSsHHzfRzleKdmJayY,12915
23
+ rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp,sha256=oM13TCQi2GMIf3f-Z39WOL8M_8GmGI_Kdhiq3Y2keJw,1643
24
+ rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h,sha256=epwsW8OUIOvrlNuW3BAmAbgB8n8CKOFEYafBxQy3ptw,2209
25
+ rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h,sha256=KYJiWmmig0Wh-zpiWV96J_be8jlyc38Ztd1iqNoqVFI,1501
26
+ rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h,sha256=CMQclcyHaDL65v7dEBOYqNNQcV332fFXmVNe-F23mJo,1526
27
+ rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu,sha256=t6Q8_M63eSlyOqcwYiGfI0HvlvQ_z0okBR4JNKqW5n0,20810
28
+ rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp,sha256=oM13TCQi2GMIf3f-Z39WOL8M_8GmGI_Kdhiq3Y2keJw,1643
29
+ rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h,sha256=epwsW8OUIOvrlNuW3BAmAbgB8n8CKOFEYafBxQy3ptw,2209
30
+ rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h,sha256=4qAa3frGI1buJanudvLT94rycS1bxmRQIA8zSNa0hBI,1501
31
+ rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h,sha256=CMQclcyHaDL65v7dEBOYqNNQcV332fFXmVNe-F23mJo,1526
32
+ rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip,sha256=givSxPA7YfKGz75rOtN8TAjTxWWraVNgTGPZfAJsZsQ,20836
33
+ rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu,sha256=tfRbMQBkl_LT7EVaJ6KoWYcQ902ApCrS6zkjXldFZXY,12770
34
+ rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp,sha256=cyCTiF--4SQiDJu7Dy_NuEhSe1vyki6JS4I2rsvT714,6659
35
+ rwkv_ops/rwkv7_kernel/__init__.py,sha256=O85hoWD0CQHcUOxJeyNeWfBUrO-s2QAmQ3WxzSts8FE,4290
36
+ rwkv_ops/rwkv7_kernel/get_jax_devices_info.py,sha256=cMIaNED7d1PvYNSyq8wNI3G7wNvcgdUj9HWRBLuSVM8,6004
37
+ rwkv_ops/rwkv7_kernel/get_torch_devices_info.py,sha256=ZL_rAM6lHB4nTOOU28Xm08qptfuIoijOMi_xwJG3KCo,7380
38
+ rwkv_ops/rwkv7_kernel/jax_op.py,sha256=C7jOvJ-ZWTFfCZBQNzMbqgoVHuDS2QCGlBsGEMM4Fn0,9140
39
+ rwkv_ops/rwkv7_kernel/mlx_op.py,sha256=SA8ou8bvnKGlxISeFOnPtzchVTDzups_AOxMPcICrx8,3458
40
+ rwkv_ops/rwkv7_kernel/native_keras_op.py,sha256=g3pUPXApQ2C8A2PnNV_fVeCrNYzNORlWJbgiqVkh8z0,2998
41
+ rwkv_ops/rwkv7_kernel/tf_eager_kernel.py,sha256=J9PNt6ZxGTVgInOSqxSsjI9pbZh8K1u3g1n0QO1jumA,5213
42
+ rwkv_ops/rwkv7_kernel/torch_op.py,sha256=p22nLG18ISihSW4oOIjUVd0NLOKgek5c7FhyBzhqcfw,14067
43
+ rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt,sha256=Dq4Ea8N2xOEej2jZpEw4MtFjUFgN0PUciejVOCSP-FM,1400
44
+ rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu,sha256=XiDtQwRme6XrkPoDvyUoYBuTFU_HgjVCfck6vloR51o,14479
45
+ rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py,sha256=E4KGpwS9eSsdSzY6S92lLCJZHFArOsrjZmDUvElMe4I,10522
46
+ rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt,sha256=nrYXT4fBKk6s7quXTwDECxg32ybofH-7Mv-nAg6XtC4,1460
47
+ rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu,sha256=yTOP8NlLZnFy5W8H77QBWmCpbnsKuvk7GbUao_aFcHI,5498
48
+ rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py,sha256=ZQ98upEvlsKbbOFkTRweyVJCRoUFAFZTtO1e2VlaivQ,6079
49
+ rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py,sha256=uHsf_1qrtRK62IvhLuzefHGPWpHXmw1p0tqmwlHcptk,346
50
+ rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py,sha256=2Voq1Bdzn0DFloiLvwINBk7akmxRWIqXIQeyafrJJGg,2138
51
+ rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py,sha256=rhmglqHIIww7yPzaSBEp9ISxhhxoUbMtV51AUDyhUd8,1425
52
+ rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py,sha256=JDfVZsMb8yMlMN3sKT3i3l3y1YQiQkyUjnSNyan5Fqc,1888
53
+ rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py,sha256=g8b_81rIIjxeknYiklRGnox24rAvEvfKRKT-5nI0Euo,1992
54
+ rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py,sha256=gQnToi1e1GZCvjWsEdWx6WakUN4Lc0JfaBSsSXYdN84,3369
55
+ rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py,sha256=4SjQ_zTZvFxsBMeWOx0JGFg9EQ4vllvEx30EcvSZJzI,853
56
+ rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py,sha256=NoOh2_hA_rdH5bmaNNMAdCgVPfWvQpf-Q8BqF926jrw,667
57
+ rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py,sha256=PAMtE6wCW2Hz39oiHLGqhxY77csQAMYdNP2najDO_Jg,1407
58
+ rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py,sha256=8jyRxE8G0Q32MyGR-AsXnyBanWfZRb1WnNEHAVRptVE,1822
59
+ rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu,sha256=t31pfvc7Lg8-slU7HyjZ-Uj5FV8DEDVkLvnQeNAb574,7957
60
+ rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp,sha256=WC0bMakl4rPe9mIMJ5CDxqPc1-vSjwDAdo17CveMpLs,3524
61
+ rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py,sha256=DUDvED2gWNS5P85if9gJg5OuAMvWIerz-xgKT6xS15A,7900
62
+ rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu,sha256=81x9kGJx-Bm69CqFtWNsRvM4qnSEhlO_Yr0rv0vvgwo,2853
63
+ rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp,sha256=B-fboKJFaP0ELdSz-K4T38EKxHJwxg7MM_Dpskw-YSk,1953
64
+ rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py,sha256=QH5lY9K4sfMYxOgnzU_sP0EkiYEO6I5ONUEXAuTjmPc,3480
65
+ rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py,sha256=_u1srIATeoHKlVTVWbWXdpkjaggugl9y-Kx_Y4pYdIY,430
66
+ rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py,sha256=CWtotXkVvHz4-rkuOqWh6zKy95jwimS9If6SU45ylW0,2103
67
+ rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py,sha256=4RJbyUTO23OxwH1rGVxeBiBVZKNHpPL_tJ7MFoDCIts,1475
68
+ rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py,sha256=zo6l0ZZUhXFu8wEFD76I0zSqFT9IXFKUKtyeaSwk380,1795
69
+ rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py,sha256=0ucN1U0EDTDqcyTPLLcsAX6FLTf2E_3toOY9p81gWYE,1858
70
+ rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py,sha256=ioPrS0NYQhpFk1j8rAxqtbwpx1CwjJQnrJEBDqVy-As,3283
71
+ rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py,sha256=54yoa3NpV64H-koURt-hUWpFHhUjwXpGvXPp2_ETCnw,825
72
+ rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py,sha256=hQkpyaa0eUyB4V3UVks7l1_dHwOrbump0FZILityBKw,611
73
+ rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py,sha256=gk6QdoT1oq5B8Hp8Ak-SGqHm8CEj3MErUeWcRsaaOQM,1470
74
+ rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py,sha256=YNhQwBmnVEmoBBPopJE0EWNZqDR9PTA9EzjqHOCwh3Y,1672
75
+ rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py,sha256=q78MWFX9S4zV2MmK5tquEaNEQLJHZTq70CvWsFwkigU,880
76
+ rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py,sha256=WiBkK_S1-DrNIRKcb8286r9skUvif_wp-VgeXoIszFI,11918
77
+ rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py,sha256=MCsnWDlwsaIaULY6LUr4V80zeaPUu4RNI0CBX6O1C4o,5721
78
+ rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py,sha256=EsToMXCYkTQuKIDALyxa_siydxvmqoYHNttDWvmxMlY,4792
79
+ rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py,sha256=KK1RX4T3EVAHe7OIRSQu9C4B8hw5fcOkqUiBSJQ0I5s,4839
80
+ rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py,sha256=sChu_EOqRyDBRspkV2awVzo9JAChAfNk51_lGyW4W6I,11059
81
+ rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py,sha256=medawQuGAVsbZ7kGW7mKyzrkmZ--zibN1nQWT1K9rqY,3428
82
+ rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py,sha256=pRp_z587PrnpgRVpi031IndyjVIi6ulMse-HlXheNTg,2182
83
+ rwkv_ops/rwkv7_kernel/triton_kernel/utils.py,sha256=TNGlkwGq4t-TOcdVBk_N_vHPLzMFTu_F0V-O1RprIO4,553
84
+ rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py,sha256=szaG11q_WmpyhXi6aVWwzizvflCh5wND8wGA_V8afzA,5479
85
+ rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py,sha256=jbb19DUTHENU2RIOv_T4m_W1eXMqdRqG0XevIkBOhI4,9438
86
+ rwkv_ops-0.6.1.dist-info/METADATA,sha256=g4-2SR9ah3U8TT-tkl87ZbCDEqUFH3d0mFp80v9-O3E,16811
87
+ rwkv_ops-0.6.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
88
+ rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
89
+ rwkv_ops-0.6.1.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.28.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any