rwkv-ops 0.1.1__tar.gz → 0.2__tar.gz
- rwkv_ops-0.2/PKG-INFO +257 -0
- rwkv_ops-0.2/README.md +237 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/__init__.py +3 -1
- rwkv_ops-0.2/rwkv_ops/rwkv6_kernel/__init__.py +126 -0
- rwkv_ops-0.2/rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +724 -0
- rwkv_ops-0.2/rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +86 -0
- rwkv_ops-0.2/rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/__init__.py +3 -7
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/torch_op.py +67 -29
- rwkv_ops-0.2/rwkv_ops.egg-info/PKG-INFO +257 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops.egg-info/SOURCES.txt +4 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/setup.py +3 -3
- rwkv_ops-0.1.1/PKG-INFO +0 -118
- rwkv_ops-0.1.1/README.md +0 -98
- rwkv_ops-0.1.1/rwkv_ops.egg-info/PKG-INFO +0 -118
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/LICENSE.txt +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/jax_op.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/native_keras_op.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops.egg-info/dependency_links.txt +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops.egg-info/requires.txt +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/rwkv_ops.egg-info/top_level.txt +0 -0
- {rwkv_ops-0.1.1 → rwkv_ops-0.2}/setup.cfg +0 -0
rwkv_ops-0.2/PKG-INFO
ADDED
@@ -0,0 +1,257 @@
Metadata-Version: 2.1
Name: rwkv_ops
Version: 0.2
Home-page: https://github.com/pass-lin/rwkv_ops
License: Apache 2.0
Keywords: rwkv implement for multi backend
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Description-Content-Type: text/markdown
License-File: LICENSE.txt

[English Document](ENREADME.md)

# RWKV OPS Project

> RWKV is under continuous iteration, so its core operators will keep being updated.
> This repository maintains only the operators themselves, not layers or models, and aims to provide GPU operators for as many frameworks as possible.

### Current support

| Operator type | Framework support |
|----------|----------|
| GPU operators | PyTorch, JAX (TensorFlow will follow once Google adds Triton support) |
| Native operators | PyTorch, JAX, TensorFlow, NumPy |

> If the Keras ecosystem grows, MLX and OpenVINO may be supported in the future.
> Note: this library depends on `keras`.

---

## Installation

```bash
pip install rwkv_ops
```

---

## Environment variables

| Variable | Meaning | Values | Default | Priority |
|---|---|---|---|---|
| `KERAS_BACKEND` | Keras backend | `jax` / `torch` / `tensorflow` / `numpy` | — | low |
| `KERNEL_BACKEND` | Operator backend | `jax` / `torch` / `tensorflow` / `numpy` | `torch` | **high** |
| `KERNEL_TYPE` | Implementation type | `triton` / `cuda` / `native` | — | — |

> If `KERNEL_BACKEND` is set, it is used directly; if it is empty, `KERAS_BACKEND` is used instead; if both are empty, the default is `torch`.
> `native` selects the native operators, which have no chunkwise computation and are therefore slower and heavier on GPU memory.
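
For reference, a minimal sketch of that resolution order (variable names here are illustrative; the package applies the same checks in `rwkv_ops/__init__.py`):

```python
import os

# KERNEL_BACKEND wins; fall back to KERAS_BACKEND; default to "torch".
backend = os.environ.get("KERNEL_BACKEND") or os.environ.get("KERAS_BACKEND") or "torch"
kernel_type = os.environ.get("KERNEL_TYPE", "triton")  # the 0.2 __init__.py defaults this to "triton"

assert backend in ("torch", "jax", "tensorflow", "numpy")
assert kernel_type in ("triton", "cuda", "native")
```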

---

## rwkv7op usage

```python
from rwkv_ops import generalized_delta_rule  # or: from rwkv_ops import rwkv7_op (fully equivalent)

def generalized_delta_rule(
    r,
    w,
    k,
    v,
    a,
    b,
    initial_state=None,
    output_final_state: bool = True,
    head_first: bool = False,
):
    """
    Chunked Delta Rule attention interface.

    Args:
        r: [B, T, H, K]  (the query; called q in some implementations)
        w: [B, T, H, K]  # decay term, in log space!
        k: [B, T, H, K]
        v: [B, T, H, V]
        a: [B, T, H, K]
        b: [B, T, H, K]
        initial_state: initial state [N, H, K, V], where N is the number of sequences
        output_final_state: whether to return the final state
        head_first: whether inputs are in head-first layout; variable-length input is not supported

    Returns:
        o: output, [B, T, H, V] or [B, H, T, V]
        final_state: final state, [N, H, K, V] or None
    """
```
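
A minimal end-to-end sketch (not from the package docs): random tensors on the torch backend, assuming a CUDA GPU with Triton available. Values are arbitrary and for shape illustration only.

```python
import os
os.environ["KERNEL_BACKEND"] = "torch"

import torch
from rwkv_ops import generalized_delta_rule

B, T, H, K, V = 2, 16, 4, 64, 64
r = torch.randn(B, T, H, K, device="cuda")
k = torch.randn(B, T, H, K, device="cuda")
v = torch.randn(B, T, H, V, device="cuda")
a = torch.randn(B, T, H, K, device="cuda")
b = torch.randn(B, T, H, K, device="cuda")
w = -torch.exp(torch.randn(B, T, H, K, device="cuda"))  # decay in log space, hence negative

o, final_state = generalized_delta_rule(r, w, k, v, a, b, output_final_state=True)
print(o.shape)            # (2, 16, 4, 64) -- [B, T, H, V]
print(final_state.shape)  # (2, 4, 64, 64) -- [N, H, K, V]
```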

### torch-cuda specifics

- Under torch-cuda, `head_size` is itself a kernel parameter; the default is 64.
- If `head_size ≠ 64`, use:

```python
from rwkv_ops import get_generalized_delta_rule

generalized_delta_rule, RWKV7_USE_KERNEL = get_generalized_delta_rule(
    your_head_size, KERNEL_TYPE="cuda"
)
```

- `RWKV7_USE_KERNEL` is a constant that records whether the chunkwise kernel is in use.
- The two code paths handle padding differently (a sketch of building such a mask follows the snippet):

```python
if padding_mask is not None:
    if RWKV7_USE_KERNEL:
        w += (1 - padding_mask) * -1e9
    else:
        w = w * padding_mask + 1 - padding_mask
```
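
One way such a `padding_mask` could be built and broadcast against `w` (the shapes are an assumption on my part; the two branches simply restate the rule above):

```python
import torch

B, T, H, K = 2, 8, 4, 64
w = torch.randn(B, T, H, K)
lengths = torch.tensor([8, 5])  # valid tokens per sequence

padding_mask = (torch.arange(T)[None, :] < lengths[:, None]).float()  # [B, T]
padding_mask = padding_mask[:, :, None, None]  # [B, T, 1, 1], broadcasts over H and K

RWKV7_USE_KERNEL = True  # as returned by get_generalized_delta_rule
if RWKV7_USE_KERNEL:
    w = w + (1 - padding_mask) * -1e9        # chunkwise-kernel rule
else:
    w = w * padding_mask + 1 - padding_mask  # native-op rule
```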

---

### rwkv7op implementation status

| Framework | cuda | triton | native |
|-------------|------|--------|--------|
| PyTorch | ✅ | ✅ | ✅ |
| JAX | ❌ | ✅ | ✅ |
| TensorFlow | ❌ | ❌ | ✅ |
| NumPy | ❌ | ❌ | ✅ |

---

## rwkv6op usage

### PyTorch notes

- Dependencies: `keras`, `ninja`, and a complete CUDA toolkit.
- If you debug with VS Code plus a virtual environment, be sure to activate the virtual environment manually in the terminal before running your code; otherwise ninja may fail to work.
- PyTorch can usually still run when the CUDA version inside the virtual environment differs from the global one, but keeping them identical is strongly recommended.
- PyTorch limitation: only **one** `RWKV6_OP` object may be instantiated within a single program; the operator is stateless and thread-safe, so it can be called from multiple places.

### JAX notes

- Dependencies: `keras`, `gcc`, `pybind11`, and a complete CUDA toolkit.
- Even if CUDA is installed for JAX inside a virtual environment, a complete system-level CUDA installation is still required, and the two versions must match so that JAX's parallel compilation stays fast.
- JAX compilation relies on the `/usr/local/cuda` symlink; create it manually if it does not exist:
```shell
sudo ln -sf /usr/local/cuda-12.4 /usr/local/cuda
```
- Make sure `nvcc -V` produces normal output and `which nvcc` points to the correct version.
- JAX limitation: only **one** `RWKV6_OP` object may be instantiated within a single program; the operator is stateless and thread-safe, so it can be called from multiple places.
- JAX ≥ 0.6.0 no longer uses the CUDA kernel and falls back to the native operator; JAX 0.4.34 is recommended.

### TensorFlow notes

- Only a native-API `RWKV6` operator is provided; it is inference-only and relatively inefficient.

---

### Usage

Note that unlike rwkv7, which is exposed as a function, the RWKV6 op is a class and must be instantiated.

```python
from rwkv_ops import RWKV6_OP

operator = RWKV6_OP(
    head_size=64,              # head size; use 64 if unsure
    max_sequence_length=4096,  # maximum sequence length during training; inference is unrestricted
    ops_loop=False             # optional: when sequence length is 1, replace CUDA with the high-level API
)
```

#### Invocation

```python
y, y_state = operator(
    r, k, v, w, u,
    with_state=False,   # whether to use a custom initial state / return the final state
    init_state=None,    # initial state, [n_state, num_heads, head_size, head_size]
    state_map=None      # 1-D int32 array of length batch_size mapping each sample to an init_state index
)
```

| Parameter | Shape | Notes |
|---|---|---|
| r, k, v, w | (batch_size, seq_len, hidden_size) | — |
| u | (num_heads, head_size) or (hidden_size,) | — |
| init_state | (n_state, num_heads, head_size, head_size) | with n_state=1 all samples share one state; with n_state=batch_size they map one-to-one |
| state_map | (batch_size,) | index of the init_state entry used by each sample |

| Return value | Shape | Notes |
|---|---|---|
| y | (batch_size, seq_len, hidden_size) | output |
| y_state | (batch_size, num_heads, head_size, head_size) or None | final state |
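
A minimal stateful-inference sketch built from the tables above (torch backend assumed; shapes and values are illustrative). The second call feeds the returned state back in, with `state_map[i] = i` mapping sample `i` to state `i`:

```python
import torch
from rwkv_ops import RWKV6_OP

batch_size, seq_len = 4, 128
head_size, num_heads = 64, 32
hidden_size = head_size * num_heads

# Remember: only one RWKV6_OP instance per program.
operator = RWKV6_OP(head_size=head_size, max_sequence_length=seq_len)

r = torch.randn(batch_size, seq_len, hidden_size, device="cuda")
k = torch.randn(batch_size, seq_len, hidden_size, device="cuda")
v = torch.randn(batch_size, seq_len, hidden_size, device="cuda")
w = torch.randn(batch_size, seq_len, hidden_size, device="cuda")
u = torch.randn(num_heads, head_size, device="cuda")

# First chunk: default all-zero initial state.
y1, state = operator(r, k, v, w, u, with_state=True)

# Second chunk: resume from the final state; sample i reads state[i].
state_map = torch.arange(batch_size, dtype=torch.int32, device="cuda")
y2, state = operator(r, k, v, w, u, with_state=True, init_state=state, state_map=state_map)
```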

---

### Distributed tips

- The operator itself has no distributed support; with PyTorch you can simply use the usual multi-worker distributed setup.
- JAX requires wrapping in `shard_map` (example):

```python
import os
os.environ['KERAS_BACKEND'] = 'jax'

import jax, jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from functools import partial
from rwkv_ops import RWKV6_OP

batch_size, seq_length = 24, 512
head_size, num_heads = 64, 32
hidden_size = head_size * num_heads

mesh = Mesh(jax.devices('gpu'), axis_names=('device_axis',))
device_ns = NamedSharding(mesh, P('device_axis'))

operator = RWKV6_OP(head_size=head_size, max_sequence_length=seq_length)

@partial(shard_map,
         mesh=mesh,
         in_specs=(P('device_axis'),) * 5,
         out_specs=(P('device_axis'), P('device_axis')),
         check_rep=False)
def call_kernel(r, k, v, w, u):
    # Drop the outermost device axis.
    r, k, v, w, u = map(jnp.squeeze, (r, k, v, w, u))
    y, ys = operator(r, k, v, w, u, with_state=True)
    return jnp.expand_dims(y, 0), jnp.expand_dims(ys, 0)

# Build the inputs and place them on their devices.
keys = jax.random.split(jax.random.PRNGKey(0), 5)
inputs = [jax.random.normal(k, (mesh.size, batch_size, seq_length, hidden_size)) for k in keys]
inputs_r, inputs_k, inputs_v, inputs_w, inputs_u = map(
    lambda x: jax.device_put(x, device_ns), inputs)
inputs_u = inputs_u[:, 0, 0]  # (devices, hidden_size)

# Optional: wrap in jax.jit(call_kernel, ...) for speed.
outputs_y, y_state = call_kernel(inputs_r, inputs_k, inputs_v, inputs_w, inputs_u)

print(outputs_y.shape, outputs_y.sharding)
print(y_state.shape, y_state.sharding)
```

The `squeeze`/`expand_dims` pair removes and then restores the per-device leading axis of size 1 that the sharding introduces, so the operator itself only ever sees unsharded `(batch, seq, hidden)` tensors.

---

### rwkv6op implementation status

| Framework | cuda | triton | native |
|-------------|------|--------|--------|
| PyTorch | ✅ | ❌ | ✅ |
| JAX | ⚠️ | ❌ | ✅ |
| TensorFlow | ❌ | ❌ | ✅ |
| NumPy | ❌ | ❌ | ✅ |

⚠️ The JAX CUDA implementation only works on JAX < 0.6.0; 0.4.34 is recommended.
rwkv_ops-0.2/README.md
ADDED
@@ -0,0 +1,237 @@
(Identical, line for line, to the Markdown body of the PKG-INFO above, i.e. everything after the `License-File:` header.)
rwkv_ops/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.1.1"
+__version__ = "0.2"
 import os

 KERNEL_TYPE = os.environ.get("KERNEL_TYPE", "triton")
@@ -20,8 +20,10 @@ else:
 assert KERNEL_TYPE in ["triton", "cuda", "native"]
 assert BACKEND in ["torch", "jax", "numpy", "tensorflow"]
 from .rwkv7_kernel import get_generalized_delta_rule
+from .rwkv6_kernel import get_rwkv6_kernel

 generalized_delta_rule, RWKV7_USE_KERNEL = get_generalized_delta_rule(
     KERNEL_TYPE=KERNEL_TYPE
 )
 rwkv7_op = generalized_delta_rule
+RWKV6_OP = get_rwkv6_kernel(KERNEL_TYPE=KERNEL_TYPE)
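
The net effect of this hunk, seen from a consumer (a hypothetical import, not from the package docs):

```python
# New in 0.2: the rwkv6 operator class is exported next to the rwkv7 function.
from rwkv_ops import RWKV6_OP, rwkv7_op, generalized_delta_rule, RWKV7_USE_KERNEL

operator = RWKV6_OP(head_size=64, max_sequence_length=4096)
```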

rwkv_ops-0.2/rwkv_ops/rwkv6_kernel/__init__.py
ADDED
@@ -0,0 +1,126 @@
# Adapted from https://github.com/infiy-quine/RWKV6_Keras_Operator
import os
import keras
from keras import ops
from distutils.util import strtobool
from packaging import version


def get_rwkv6_kernel(KERNEL_TYPE="native"):
    ops_kernel = True
    if KERNEL_TYPE == "cuda":
        if keras.config.backend() == "jax":
            import jax

            if version.parse(jax.__version__) < version.parse("0.6.0"):
                from .jax_rwkv_kernel import RWKVKernelOperator as CudaOperator

                ops_kernel = False
            else:
                print(
                    "The RWKV6 CUDA kernel cannot be used with JAX >= 0.6. "
                    "If you need it, please downgrade JAX; 0.4.34 is recommended."
                )
                CudaOperator = None
        elif keras.config.backend() == "torch":
            from .torch_rwkv_kernel import RWKVKernelOperator as CudaOperator

            ops_kernel = False
        else:
            CudaOperator = None
    else:
        CudaOperator = None
    from .ops_rwkv_kernel import RWKVKernelOperator as OpsOperator

    class RWKVKernelOperator:
        def __init__(self, head_size, max_sequence_length, ops_loop=False):
            self.enable_cuda = CudaOperator is not None

            if self.enable_cuda:
                self.cuda_operator = CudaOperator(head_size, max_sequence_length)

            self.ops_operator = OpsOperator(head_size, max_sequence_length)
            self.ops_loop = ops_loop

        def __call__(
            self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
        ):
            seq_len = r.shape[1]

            def call_parallel():
                # Prefer the CUDA kernel when one was loaded for this backend.
                if self.enable_cuda:
                    return self.cuda_operator(
                        r=r, k=k, v=v, w=w, u=u,
                        with_state=with_state,
                        init_state=init_state,
                        state_map=state_map,
                    )
                else:
                    return self.ops_operator(
                        r=r, k=k, v=v, w=w, u=u,
                        with_state=with_state,
                        init_state=init_state,
                        state_map=state_map,
                    )

            def call_one_step():
                # Single-step fallback on the high-level ops implementation.
                return self.ops_operator(
                    r=r, k=k, v=v, w=w, u=u,
                    with_state=with_state,
                    init_state=init_state,
                    state_map=state_map,
                )

            if not self.ops_loop:
                return ops.cond(
                    seq_len != 1 and not ops_kernel, call_parallel, call_one_step
                )
            else:
                return call_parallel()

    return RWKVKernelOperator


"""
Three added parameters:

return_state: bool. Whether to return the final state; enable it as well if you
    want to pass a custom init_state.

init_state:
    When omitted, the state along the batch dimension is zero-initialized.
    Shape: (state_kinds, num_heads, head_size, head_size), where state_kinds is a
        positive integer no larger than batch_size.
    Precision: fp32 when r is fp16; otherwise the same dtype as r.

state_map:
    Shape: (batch_size,)
    Dtype: int64 or list[int]
    Defines the mapping from states to the batch-dimension slices of r.
    Value range: [0, state_kinds)

Returns:
    output, output_state

def __call__(self, r, k, v, w, u, return_state=False, init_state=None, state_map=None):
"""
|