rwkv-ops 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (31)
  1. rwkv_ops/__init__.py +5 -6
  2. rwkv_ops/rwkv6_kernel/__init__.py +0 -6
  3. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  4. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  5. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  6. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  7. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  8. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  9. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  10. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  11. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  12. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  13. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +21 -23
  14. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +14 -10
  15. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  16. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  17. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +4 -4
  18. rwkv_ops/rwkv7_kernel/__init__.py +80 -29
  19. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  20. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +279 -0
  21. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +237 -0
  22. rwkv_ops/rwkv7_kernel/jax_op.py +6 -5
  23. rwkv_ops/rwkv7_kernel/native_keras_op.py +5 -6
  24. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +123 -0
  25. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +165 -0
  26. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +35 -0
  27. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/METADATA +28 -27
  28. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/RECORD +30 -13
  29. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/WHEEL +1 -2
  30. rwkv_ops-0.2.2.dist-info/top_level.txt +0 -1
  31. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info/licenses}/LICENSE.txt +0 -0
{rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/METADATA
@@ -1,23 +1,17 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: rwkv-ops
- Version: 0.2.2
- Home-page: https://github.com/pass-lin/rwkv_ops
- License: Apache 2.0
- Keywords: rwkv implement for multi backend
- Classifier: Development Status :: 3 - Alpha
- Classifier: Intended Audience :: Developers
- Classifier: Intended Audience :: Science/Research
+ Version: 0.3.1
+ Summary: RWKV operators for multiple backends (PyTorch, JAX, Keras)
+ Project-URL: Homepage, https://github.com/pass-lin/rwkv_ops
+ Author-email: pass-lin <qw_lin@qq.com>
+ License: Apache-2.0
+ License-File: LICENSE.txt
  Classifier: License :: OSI Approved :: Apache Software License
  Classifier: Operating System :: OS Independent
  Classifier: Programming Language :: Python :: 3
- Classifier: Programming Language :: Python :: 3.8
- Classifier: Programming Language :: Python :: 3.9
- Classifier: Programming Language :: Python :: 3.10
- Classifier: Programming Language :: Python :: 3.11
- Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.8
+ Requires-Dist: keras>=3.0
  Description-Content-Type: text/markdown
- License-File: LICENSE.txt
- Requires-Dist: keras
 
  [English Document](ENREADME.md)
 
@@ -29,7 +23,7 @@ Requires-Dist: keras
  ### Currently supported
  | Operator type | Supported frameworks |
  |----------|----------|
- | GPU kernels | PyTorch, JAX (TensorFlow to follow once Google adds Triton support) |
+ | GPU kernels | PyTorch, JAX |
  | Native kernels | PyTorch, JAX, TensorFlow, NumPy |
 
  > If the Keras ecosystem expands, MLX and OpenVINO may be supported in the future.
@@ -43,6 +37,12 @@ Requires-Dist: keras
  pip install rwkv_ops
  ```
 
+ Because pip uninstall cannot fully remove the compiled kernels installed from the pip package, you may prefer to install from source:
+ ```bash
+ git clone https://github.com/pass-lin/rwkv_ops.git
+ cd rwkv_ops
+ bash install.sh
+ ```
  ---
 
  ## Environment variables
@@ -51,10 +51,9 @@ pip install rwkv_ops
  |---|---|---|---|---|
  | `KERAS_BACKEND` | Keras backend | `jax` / `torch` / `tensorflow` / `numpy` | — | Low |
  | `KERNEL_BACKEND` | Kernel backend | `jax` / `torch` / `tensorflow` / `numpy` | `torch` | **High** |
- | `KERNEL_TYPE` | Implementation type | `triton` / `cuda` / `native` | | — |
+ | `KERNEL_TYPE` | Implementation type | `triton` / `cuda` / `native` | `cuda` | — |
 
  > If `KERNEL_BACKEND` is set, it is used directly; if it is empty, `KERAS_BACKEND` is used; if both are empty, the default is `torch`.
- > `native` is the pure native kernel: it has no chunkwise algorithm, so it is slow and memory-hungry.
 
  ---
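The fallback rule described in that note can be sketched in Python (`resolve_kernel_backend` is a hypothetical helper for illustration, not a function exported by rwkv_ops):

```python
import os

def resolve_kernel_backend() -> str:
    """Mirror the documented fallback: KERNEL_BACKEND wins if set,
    otherwise KERAS_BACKEND, otherwise the default "torch"."""
    # os.environ.get returns None when unset; `or` also skips empty strings,
    # so an exported-but-empty variable falls through as well
    return (
        os.environ.get("KERNEL_BACKEND")
        or os.environ.get("KERAS_BACKEND")
        or "torch"
    )
```

For example, with `KERNEL_BACKEND` unset and `KERAS_BACKEND=jax`, the helper resolves to `jax`.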
@@ -102,34 +101,36 @@ def generalized_delta_rule(
  ```python
  from rwkv_ops import get_generalized_delta_rule
 
- generalized_delta_rule, RWKV7_USE_KERNEL = get_generalized_delta_rule(
+ generalized_delta_rule, USE_TRITON_KERNEL = get_generalized_delta_rule(
      your_head_size, KERNEL_TYPE="cuda"
  )
  ```
 
- - `RWKV7_USE_KERNEL` is a constant marking whether the chunkwise kernel is in use.
+ - `USE_TRITON_KERNEL` is a constant marking whether the chunkwise kernel is in use.
  - The two differ in how padding is handled:
 
  ```python
  if padding_mask is not None:
-     if RWKV7_USE_KERNEL:
-         w += (1 - padding_mask) * -1e9
-     else:
-         w = w * padding_mask + 1 - padding_mask
+     w += (1 - padding_mask) * -1e9
  ```
+ - With the code above, the loop-based kernels handle both left padding and right padding correctly.
+ - With the chunkwise kernel, uniform left padding is recommended; the cuda and native kernels handle both left and right padding correctly.
 
- ---
 
  ### rwkv7op implementation status
 
  | Framework | cuda | triton | native |
  |-------------|------|--------|--------|
  | PyTorch | ✅ | ✅ | ✅ |
- | JAX | | ✅ | ✅ |
- | TensorFlow | | ❌ | ✅ |
+ | JAX | | ✅ | ✅ |
+ | TensorFlow | ⚠️ | ❌ | ✅ |
  | NumPy | ❌ | ❌ | ✅ |
 
  ---
+ > `native` is the pure native kernel: it has no chunkwise algorithm, so it is slow and memory-hungry.
+ > `triton` uses a chunkwise implementation: fast and highly parallel, but its numerical precision is poor; avoid it if precision matters to you.
+ > `cuda` is a native CUDA kernel: very fast, and it computes in fp32 internally, so precision is also high. Its drawback is that it cannot stay fully utilized on long sequences.
+ > The TensorFlow CUDA implementation only supports the forward pass and has no gradients. It is built on the JAX CUDA implementation, so you must be able to run the JAX CUDA kernel successfully.
 
  ## rwkv6op usage
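As a quick numeric check of the two masking formulas above (NumPy stands in for the backend tensors; the shapes are toy values, not the kernel's real layout):

```python
import numpy as np

# Toy example: T = 4 steps, the first step is (left) padding.
w = np.zeros(4)                                # decay logits for one channel
padding_mask = np.array([0.0, 1.0, 1.0, 1.0])  # 0 marks a padded step

# Chunkwise-kernel path: push padded steps to a large negative value,
# so their contribution vanishes once the kernel exponentiates w.
w_chunk = w + (1 - padding_mask) * -1e9

# Loop-kernel path from the 0.2.2 README: blend toward 1 so padded
# steps leave the recurrent state unchanged.
w_loop = w * padding_mask + 1 - padding_mask
```

Here `w_chunk` ends up as `[-1e9, 0, 0, 0]` and `w_loop` as `[1, 0, 0, 0]`: both neutralize the padded step, just through different mechanisms.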
{rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/RECORD
@@ -1,14 +1,30 @@
- rwkv_ops/__init__.py,sha256=R44XEpC4o1C-bUJy43YokONCD5Jxp-kBp3eyNZNqCIk,843
- rwkv_ops/rwkv6_kernel/__init__.py,sha256=_j6G_3fY8xPxrlZbgDT2ndX4IPiNJ4qjqIcdmNI_r9Q,4100
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py,sha256=WOzqfQQSHHMoWqm2kRz_BhtMzGYc5USJ26qaEwuARo4,30117
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py,sha256=otjfw5n6nf2YVpBIWIZjaCsxMyLXXwg-ma1ueXX-EdY,3274
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py,sha256=Q1uPMgaS21OEfQ8-sBDjaCUASMtkSOdN3OosEUsBp9U,12918
- rwkv_ops/rwkv7_kernel/__init__.py,sha256=20UIBG-nydjB72iboCJ7TDIwgXVF1aYu93sqApFBqc8,5557
+ rwkv_ops/__init__.py,sha256=aQr_HZZ-tB6b_HTjnrbRNyPnExrq0qYReK7-UTqdFTA,855
+ rwkv_ops/rwkv6_kernel/__init__.py,sha256=ktIzkK6EUc2nonLQnl2NAjJj9kMt02i9zqfjFcnM_NQ,3647
+ rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py,sha256=4SL93Z4mmuQldHtmwqTKcP7M-outTU5Rge2qgDGzwBg,29966
+ rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py,sha256=c3ZSJ9xC6-PKr88pOhjmBximdhwmP1_i7UOcIdKB43c,3354
+ rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py,sha256=Pv0WsBp5byTSwkYrYkHcJa3wftSsHHzfRzleKdmJayY,12915
+ rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp,sha256=oM13TCQi2GMIf3f-Z39WOL8M_8GmGI_Kdhiq3Y2keJw,1643
+ rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h,sha256=epwsW8OUIOvrlNuW3BAmAbgB8n8CKOFEYafBxQy3ptw,2209
+ rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h,sha256=KYJiWmmig0Wh-zpiWV96J_be8jlyc38Ztd1iqNoqVFI,1501
+ rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h,sha256=CMQclcyHaDL65v7dEBOYqNNQcV332fFXmVNe-F23mJo,1526
+ rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu,sha256=t6Q8_M63eSlyOqcwYiGfI0HvlvQ_z0okBR4JNKqW5n0,20810
+ rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp,sha256=oM13TCQi2GMIf3f-Z39WOL8M_8GmGI_Kdhiq3Y2keJw,1643
+ rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h,sha256=epwsW8OUIOvrlNuW3BAmAbgB8n8CKOFEYafBxQy3ptw,2209
+ rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h,sha256=4qAa3frGI1buJanudvLT94rycS1bxmRQIA8zSNa0hBI,1501
+ rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h,sha256=CMQclcyHaDL65v7dEBOYqNNQcV332fFXmVNe-F23mJo,1526
+ rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip,sha256=givSxPA7YfKGz75rOtN8TAjTxWWraVNgTGPZfAJsZsQ,20836
+ rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu,sha256=tfRbMQBkl_LT7EVaJ6KoWYcQ902ApCrS6zkjXldFZXY,12770
+ rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp,sha256=cyCTiF--4SQiDJu7Dy_NuEhSe1vyki6JS4I2rsvT714,6659
+ rwkv_ops/rwkv7_kernel/__init__.py,sha256=1TFr3m5kLC0Ldv0HEnzt58_g84e7KD81R6s8TpmiWmI,8148
  rwkv_ops/rwkv7_kernel/get_jax_devices_info.py,sha256=cMIaNED7d1PvYNSyq8wNI3G7wNvcgdUj9HWRBLuSVM8,6004
  rwkv_ops/rwkv7_kernel/get_torch_devices_info.py,sha256=ZL_rAM6lHB4nTOOU28Xm08qptfuIoijOMi_xwJG3KCo,7380
- rwkv_ops/rwkv7_kernel/jax_op.py,sha256=l7HMVUqf3M9YmE1OADoin93HpVsCqm4rnrKoH9s6Dzg,9158
- rwkv_ops/rwkv7_kernel/native_keras_op.py,sha256=QPrXLbqw0chipQg_0jepRp2U19BYpBBFdKZWyaDNNoc,2488
+ rwkv_ops/rwkv7_kernel/jax_op.py,sha256=C7jOvJ-ZWTFfCZBQNzMbqgoVHuDS2QCGlBsGEMM4Fn0,9140
+ rwkv_ops/rwkv7_kernel/native_keras_op.py,sha256=dCWdzuVZxAKHCBURZqgOLN3n_yKFFNX5uORlbvztH6w,2502
+ rwkv_ops/rwkv7_kernel/tf_eager_kernel.py,sha256=2t2uf1iNznYpYFlqt9REY0GwGeycYuaJl-4QFk2rJHc,4357
  rwkv_ops/rwkv7_kernel/torch_op.py,sha256=jw_AvqshTAG4t9-MRqxFQNi_bTzxNbx3lwnMifPk8-8,14070
+ rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt,sha256=Dq4Ea8N2xOEej2jZpEw4MtFjUFgN0PUciejVOCSP-FM,1400
+ rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu,sha256=WePveEdUixaQA51hJUK8Sr7Q7jDTstybEWZczdjuGSo,9690
+ rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py,sha256=3lvCKIa9DO7MY3aZNyJM0AyHlQUvDKGsnYVr8MLl7Vg,7998
  rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py,sha256=uHsf_1qrtRK62IvhLuzefHGPWpHXmw1p0tqmwlHcptk,346
  rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py,sha256=2Voq1Bdzn0DFloiLvwINBk7akmxRWIqXIQeyafrJJGg,2138
  rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py,sha256=rhmglqHIIww7yPzaSBEp9ISxhhxoUbMtV51AUDyhUd8,1425
@@ -19,6 +35,8 @@ rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py,sha256=4SjQ_zTZvFxsBMeWOx0JGFg9E
  rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py,sha256=NoOh2_hA_rdH5bmaNNMAdCgVPfWvQpf-Q8BqF926jrw,667
  rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py,sha256=PAMtE6wCW2Hz39oiHLGqhxY77csQAMYdNP2najDO_Jg,1407
  rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py,sha256=8jyRxE8G0Q32MyGR-AsXnyBanWfZRb1WnNEHAVRptVE,1822
+ rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu,sha256=uH24A3Z8lItMRc7jq0ybswmwiJGKT3BsAlW18hg_7Gc,5040
+ rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp,sha256=Wk5QYvIM9m-YJdSEh6zSzVKaw1v2lphvupbwA0GHmGw,2201
  rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py,sha256=_u1srIATeoHKlVTVWbWXdpkjaggugl9y-Kx_Y4pYdIY,430
  rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py,sha256=CWtotXkVvHz4-rkuOqWh6zKy95jwimS9If6SU45ylW0,2103
  rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py,sha256=4RJbyUTO23OxwH1rGVxeBiBVZKNHpPL_tJ7MFoDCIts,1475
@@ -40,8 +58,7 @@ rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py,sha256=pRp_z587PrnpgRVpi031IndyjVI
  rwkv_ops/rwkv7_kernel/triton_kernel/utils.py,sha256=TNGlkwGq4t-TOcdVBk_N_vHPLzMFTu_F0V-O1RprIO4,553
  rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py,sha256=szaG11q_WmpyhXi6aVWwzizvflCh5wND8wGA_V8afzA,5479
  rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py,sha256=jbb19DUTHENU2RIOv_T4m_W1eXMqdRqG0XevIkBOhI4,9438
- rwkv_ops-0.2.2.dist-info/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
- rwkv_ops-0.2.2.dist-info/METADATA,sha256=bBugt-UogSOxR0YxdSxnoAEPB7wb7oHhTrzlCEY5h_8,8411
- rwkv_ops-0.2.2.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- rwkv_ops-0.2.2.dist-info/top_level.txt,sha256=cVqoKE-WR_e2gHL87-6O4K1kG6-yTJGB2huyr6FmD2I,9
- rwkv_ops-0.2.2.dist-info/RECORD,,
+ rwkv_ops-0.3.1.dist-info/METADATA,sha256=cCo4QhAeyVGbAMMO7FC0PxskVRYPmKU30FUYbrqoVnw,8853
+ rwkv_ops-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ rwkv_ops-0.3.1.dist-info/licenses/LICENSE.txt,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+ rwkv_ops-0.3.1.dist-info/RECORD,,
{rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.1.dist-info}/WHEEL
@@ -1,5 +1,4 @@
  Wheel-Version: 1.0
- Generator: bdist_wheel (0.43.0)
+ Generator: hatchling 1.27.0
  Root-Is-Purelib: true
  Tag: py3-none-any
-
rwkv_ops-0.2.2.dist-info/top_level.txt
@@ -1 +0,0 @@
- rwkv_ops