rwkv-ops 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rwkv-ops might be problematic. Click here for more details.
- rwkv_ops/__init__.py +5 -6
- rwkv_ops/rwkv6_kernel/__init__.py +0 -6
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +21 -23
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +14 -10
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +4 -4
- rwkv_ops/rwkv7_kernel/__init__.py +80 -29
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +279 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +237 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +6 -5
- rwkv_ops/rwkv7_kernel/native_keras_op.py +5 -6
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +123 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +165 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +35 -0
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/METADATA +28 -27
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/RECORD +30 -13
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/WHEEL +1 -2
- rwkv_ops-0.2.2.dist-info/top_level.txt +0 -1
- {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info/licenses}/LICENSE.txt +0 -0
|
@@ -1,23 +1,17 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: rwkv-ops
|
|
3
|
-
Version: 0.
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
Classifier: Intended Audience :: Science/Research
|
|
3
|
+
Version: 0.3.1
|
|
4
|
+
Summary: RWKV operators for multiple backends (PyTorch, JAX, Keras)
|
|
5
|
+
Project-URL: Homepage, https://github.com/pass-lin/rwkv_ops
|
|
6
|
+
Author-email: pass-lin <qw_lin@qq.com>
|
|
7
|
+
License: Apache-2.0
|
|
8
|
+
License-File: LICENSE.txt
|
|
10
9
|
Classifier: License :: OSI Approved :: Apache Software License
|
|
11
10
|
Classifier: Operating System :: OS Independent
|
|
12
11
|
Classifier: Programming Language :: Python :: 3
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
16
|
-
Classifier: Programming Language :: Python :: 3.11
|
|
17
|
-
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
12
|
+
Requires-Python: >=3.8
|
|
13
|
+
Requires-Dist: keras>=3.0
|
|
18
14
|
Description-Content-Type: text/markdown
|
|
19
|
-
License-File: LICENSE.txt
|
|
20
|
-
Requires-Dist: keras
|
|
21
15
|
|
|
22
16
|
[English Document](ENREADME.md)
|
|
23
17
|
|
|
@@ -29,7 +23,7 @@ Requires-Dist: keras
|
|
|
29
23
|
### 当前支持
|
|
30
24
|
| 算子类型 | 框架支持 |
|
|
31
25
|
|----------|----------|
|
|
32
|
-
| GPU 算子 | PyTorch、JAX
|
|
26
|
+
| GPU 算子 | PyTorch、JAX|
|
|
33
27
|
| 原生算子 | PyTorch、JAX、TensorFlow、NumPy |
|
|
34
28
|
|
|
35
29
|
> 未来若 Keras 生态扩展,可能支持 MLX、OpenVINO。
|
|
@@ -43,6 +37,12 @@ Requires-Dist: keras
|
|
|
43
37
|
pip install rwkv_ops
|
|
44
38
|
```
|
|
45
39
|
|
|
40
|
+
当然pip包对于编译的算子pip uninstall没法删干净,所以可以试着从源码安装
|
|
41
|
+
```bash
|
|
42
|
+
git clone https://github.com/pass-lin/rwkv_ops.git
|
|
43
|
+
cd rwkv_ops
|
|
44
|
+
bash install.sh
|
|
45
|
+
```
|
|
46
46
|
---
|
|
47
47
|
|
|
48
48
|
## 环境变量
|
|
@@ -51,10 +51,9 @@ pip install rwkv_ops
|
|
|
51
51
|
|---|---|---|---|---|
|
|
52
52
|
| `KERAS_BACKEND` | Keras 后端 | `jax` / `torch` / `tensorflow` / `numpy` | — | 低 |
|
|
53
53
|
| `KERNEL_BACKEND` | 算子后端 | `jax` / `torch` / `tensorflow` / `numpy` | `torch` | **高** |
|
|
54
|
-
| `KERNEL_TYPE` | 实现类型 | `triton` / `cuda` / `native` |
|
|
54
|
+
| `KERNEL_TYPE` | 实现类型 | `triton` / `cuda` / `native` | `cuda` | — |
|
|
55
55
|
|
|
56
56
|
> 若 `KERNEL_BACKEND` 有值,直接采用;若为空,则用 `KERAS_BACKEND`;两者皆空则默认 `torch`。
|
|
57
|
-
> `native` 为原生算子,无 chunkwise,速度慢且显存高。
|
|
58
57
|
|
|
59
58
|
---
|
|
60
59
|
|
|
@@ -102,34 +101,36 @@ def generalized_delta_rule(
|
|
|
102
101
|
```python
|
|
103
102
|
from rwkv_ops import get_generalized_delta_rule
|
|
104
103
|
|
|
105
|
-
generalized_delta_rule,
|
|
104
|
+
generalized_delta_rule, USE_TRITON_KERNEL = get_generalized_delta_rule(
|
|
106
105
|
your_head_size, KERNEL_TYPE="cuda"
|
|
107
106
|
)
|
|
108
107
|
```
|
|
109
108
|
|
|
110
|
-
- `
|
|
109
|
+
- `USE_TRITON_KERNEL` 为常量,标记是否使用 chunkwise 算子。
|
|
111
110
|
- 两者 padding 处理逻辑不同:
|
|
112
111
|
|
|
113
112
|
```python
|
|
114
113
|
if padding_mask is not None:
|
|
115
|
-
|
|
116
|
-
w += (1 - padding_mask) * -1e9
|
|
117
|
-
else:
|
|
118
|
-
w = w * padding_mask + 1 - padding_mask
|
|
114
|
+
w += (1 - padding_mask) * -1e9
|
|
119
115
|
```
|
|
116
|
+
- 对于上面的代码,基于循环的算子对left padding和right padding都能成功处理。
|
|
117
|
+
- 而如果用的是chunkwise算子,建议统一使用left padding;如果是cuda或者原生实现,则left和right padding都能正确处理
|
|
120
118
|
|
|
121
|
-
---
|
|
122
119
|
|
|
123
120
|
### rwkv7op 实现状态
|
|
124
121
|
|
|
125
122
|
| Framework | cuda | triton | native |
|
|
126
123
|
|-------------|------|--------|--------|
|
|
127
124
|
| PyTorch | ✅ | ✅ | ✅ |
|
|
128
|
-
| JAX |
|
|
129
|
-
| TensorFlow |
|
|
125
|
+
| JAX | ✅ | ✅ | ✅ |
|
|
126
|
+
| TensorFlow | ⚠️ | ❌ | ✅ |
|
|
130
127
|
| NumPy | ❌ | ❌ | ✅ |
|
|
131
128
|
|
|
132
129
|
---
|
|
130
|
+
> `native` 为原生算子,无 chunkwise,速度慢且显存高。
|
|
131
|
+
> `triton` 使用的是chunkwise算法实现,速度快,并行度高,缺点是精度很差,介意勿用
|
|
132
|
+
> `cuda` 为基于 CUDA 的原生算子,速度很快,并且kernel内部使用fp32实现,所以精度也很高。缺点就是长序列的时候比较吃亏跑不满。
|
|
133
|
+
> tensorflow的CUDA实现只支持前向计算,是没有梯度的。并且它是基于jax的cuda实现构建的,你需要保证你能够成功运行jax的cuda kernel。
|
|
133
134
|
|
|
134
135
|
## rwkv6op 使用方法
|
|
135
136
|
|
|
@@ -1,14 +1,30 @@
|
|
|
1
|
-
rwkv_ops/__init__.py,sha256=
|
|
2
|
-
rwkv_ops/rwkv6_kernel/__init__.py,sha256=
|
|
3
|
-
rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py,sha256=
|
|
4
|
-
rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py,sha256=
|
|
5
|
-
rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py,sha256=
|
|
6
|
-
rwkv_ops/
|
|
1
|
+
rwkv_ops/__init__.py,sha256=aQr_HZZ-tB6b_HTjnrbRNyPnExrq0qYReK7-UTqdFTA,855
|
|
2
|
+
rwkv_ops/rwkv6_kernel/__init__.py,sha256=ktIzkK6EUc2nonLQnl2NAjJj9kMt02i9zqfjFcnM_NQ,3647
|
|
3
|
+
rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py,sha256=4SL93Z4mmuQldHtmwqTKcP7M-outTU5Rge2qgDGzwBg,29966
|
|
4
|
+
rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py,sha256=c3ZSJ9xC6-PKr88pOhjmBximdhwmP1_i7UOcIdKB43c,3354
|
|
5
|
+
rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py,sha256=Pv0WsBp5byTSwkYrYkHcJa3wftSsHHzfRzleKdmJayY,12915
|
|
6
|
+
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp,sha256=oM13TCQi2GMIf3f-Z39WOL8M_8GmGI_Kdhiq3Y2keJw,1643
|
|
7
|
+
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h,sha256=epwsW8OUIOvrlNuW3BAmAbgB8n8CKOFEYafBxQy3ptw,2209
|
|
8
|
+
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h,sha256=KYJiWmmig0Wh-zpiWV96J_be8jlyc38Ztd1iqNoqVFI,1501
|
|
9
|
+
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h,sha256=CMQclcyHaDL65v7dEBOYqNNQcV332fFXmVNe-F23mJo,1526
|
|
10
|
+
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu,sha256=t6Q8_M63eSlyOqcwYiGfI0HvlvQ_z0okBR4JNKqW5n0,20810
|
|
11
|
+
rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp,sha256=oM13TCQi2GMIf3f-Z39WOL8M_8GmGI_Kdhiq3Y2keJw,1643
|
|
12
|
+
rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h,sha256=epwsW8OUIOvrlNuW3BAmAbgB8n8CKOFEYafBxQy3ptw,2209
|
|
13
|
+
rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h,sha256=4qAa3frGI1buJanudvLT94rycS1bxmRQIA8zSNa0hBI,1501
|
|
14
|
+
rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h,sha256=CMQclcyHaDL65v7dEBOYqNNQcV332fFXmVNe-F23mJo,1526
|
|
15
|
+
rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip,sha256=givSxPA7YfKGz75rOtN8TAjTxWWraVNgTGPZfAJsZsQ,20836
|
|
16
|
+
rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu,sha256=tfRbMQBkl_LT7EVaJ6KoWYcQ902ApCrS6zkjXldFZXY,12770
|
|
17
|
+
rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp,sha256=cyCTiF--4SQiDJu7Dy_NuEhSe1vyki6JS4I2rsvT714,6659
|
|
18
|
+
rwkv_ops/rwkv7_kernel/__init__.py,sha256=1TFr3m5kLC0Ldv0HEnzt58_g84e7KD81R6s8TpmiWmI,8148
|
|
7
19
|
rwkv_ops/rwkv7_kernel/get_jax_devices_info.py,sha256=cMIaNED7d1PvYNSyq8wNI3G7wNvcgdUj9HWRBLuSVM8,6004
|
|
8
20
|
rwkv_ops/rwkv7_kernel/get_torch_devices_info.py,sha256=ZL_rAM6lHB4nTOOU28Xm08qptfuIoijOMi_xwJG3KCo,7380
|
|
9
|
-
rwkv_ops/rwkv7_kernel/jax_op.py,sha256=
|
|
10
|
-
rwkv_ops/rwkv7_kernel/native_keras_op.py,sha256=
|
|
21
|
+
rwkv_ops/rwkv7_kernel/jax_op.py,sha256=C7jOvJ-ZWTFfCZBQNzMbqgoVHuDS2QCGlBsGEMM4Fn0,9140
|
|
22
|
+
rwkv_ops/rwkv7_kernel/native_keras_op.py,sha256=dCWdzuVZxAKHCBURZqgOLN3n_yKFFNX5uORlbvztH6w,2502
|
|
23
|
+
rwkv_ops/rwkv7_kernel/tf_eager_kernel.py,sha256=2t2uf1iNznYpYFlqt9REY0GwGeycYuaJl-4QFk2rJHc,4357
|
|
11
24
|
rwkv_ops/rwkv7_kernel/torch_op.py,sha256=jw_AvqshTAG4t9-MRqxFQNi_bTzxNbx3lwnMifPk8-8,14070
|
|
25
|
+
rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt,sha256=Dq4Ea8N2xOEej2jZpEw4MtFjUFgN0PUciejVOCSP-FM,1400
|
|
26
|
+
rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu,sha256=WePveEdUixaQA51hJUK8Sr7Q7jDTstybEWZczdjuGSo,9690
|
|
27
|
+
rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py,sha256=3lvCKIa9DO7MY3aZNyJM0AyHlQUvDKGsnYVr8MLl7Vg,7998
|
|
12
28
|
rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py,sha256=uHsf_1qrtRK62IvhLuzefHGPWpHXmw1p0tqmwlHcptk,346
|
|
13
29
|
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py,sha256=2Voq1Bdzn0DFloiLvwINBk7akmxRWIqXIQeyafrJJGg,2138
|
|
14
30
|
rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py,sha256=rhmglqHIIww7yPzaSBEp9ISxhhxoUbMtV51AUDyhUd8,1425
|
|
@@ -19,6 +35,8 @@ rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py,sha256=4SjQ_zTZvFxsBMeWOx0JGFg9E
|
|
|
19
35
|
rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py,sha256=NoOh2_hA_rdH5bmaNNMAdCgVPfWvQpf-Q8BqF926jrw,667
|
|
20
36
|
rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py,sha256=PAMtE6wCW2Hz39oiHLGqhxY77csQAMYdNP2najDO_Jg,1407
|
|
21
37
|
rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py,sha256=8jyRxE8G0Q32MyGR-AsXnyBanWfZRb1WnNEHAVRptVE,1822
|
|
38
|
+
rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu,sha256=uH24A3Z8lItMRc7jq0ybswmwiJGKT3BsAlW18hg_7Gc,5040
|
|
39
|
+
rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp,sha256=Wk5QYvIM9m-YJdSEh6zSzVKaw1v2lphvupbwA0GHmGw,2201
|
|
22
40
|
rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py,sha256=_u1srIATeoHKlVTVWbWXdpkjaggugl9y-Kx_Y4pYdIY,430
|
|
23
41
|
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py,sha256=CWtotXkVvHz4-rkuOqWh6zKy95jwimS9If6SU45ylW0,2103
|
|
24
42
|
rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py,sha256=4RJbyUTO23OxwH1rGVxeBiBVZKNHpPL_tJ7MFoDCIts,1475
|
|
@@ -40,8 +58,7 @@ rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py,sha256=pRp_z587PrnpgRVpi031IndyjVI
|
|
|
40
58
|
rwkv_ops/rwkv7_kernel/triton_kernel/utils.py,sha256=TNGlkwGq4t-TOcdVBk_N_vHPLzMFTu_F0V-O1RprIO4,553
|
|
41
59
|
rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py,sha256=szaG11q_WmpyhXi6aVWwzizvflCh5wND8wGA_V8afzA,5479
|
|
42
60
|
rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py,sha256=jbb19DUTHENU2RIOv_T4m_W1eXMqdRqG0XevIkBOhI4,9438
|
|
43
|
-
rwkv_ops-0.
|
|
44
|
-
rwkv_ops-0.
|
|
45
|
-
rwkv_ops-0.
|
|
46
|
-
rwkv_ops-0.
|
|
47
|
-
rwkv_ops-0.2.2.dist-info/RECORD,,
|
|
61
|
+
rwkv_ops-0.3.1.dist-info/METADATA,sha256=cCo4QhAeyVGbAMMO7FC0PxskVRYPmKU30FUYbrqoVnw,8853
|
|
62
|
+
rwkv_ops-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
63
|
+
rwkv_ops-0.3.1.dist-info/licenses/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
|
|
64
|
+
rwkv_ops-0.3.1.dist-info/RECORD,,
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
rwkv_ops
|
|
File without changes
|