rwkv-ops 0.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/rwkv6_kernel/__init__.py
@@ -0,0 +1,120 @@
+ # copy right from https://github.com/infiy-quine/RWKV6_Keras_Operator
+ import os
+ import keras
+ from keras import ops
+ from distutils.util import strtobool
+ from packaging import version
+
+
+ def get_rwkv6_kernel(KERNEL_TYPE="native"):
+     ops_kernel = True
+     if KERNEL_TYPE == "cuda":
+         if keras.config.backend() == "jax":
+             import jax
+
+             if version.parse(jax.__version__) < version.parse("0.6.0"):
+                 from .jax_rwkv_kernel import RWKVKernelOperator as CudaOperator
+
+                 ops_kernel = False
+             else:
+                 CudaOperator = None
+         elif keras.config.backend() == "torch":
+             from .torch_rwkv_kernel import RWKVKernelOperator as CudaOperator
+
+             ops_kernel = False
+         else:
+             CudaOperator = None
+     else:
+         CudaOperator = None
+     from .ops_rwkv_kernel import RWKVKernelOperator as OpsOperator
+
+     class RWKVKernelOperator:
+         def __init__(self, head_size, max_sequence_length, ops_loop=False):
+             self.enbale_cuda = CudaOperator is not None
+
+             if self.enbale_cuda:
+                 self.cuda_operator = CudaOperator(head_size, max_sequence_length)
+
+             self.ops_operator = OpsOperator(head_size, max_sequence_length)
+
+             self.ops_loop = ops_loop
+
+         def __call__(
+             self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
+         ):
+             seq_len = r.shape[1]
+
+             def call_parallel():
+                 if self.enbale_cuda:
+                     return self.cuda_operator(
+                         r=r,
+                         k=k,
+                         v=v,
+                         w=w,
+                         u=u,
+                         with_state=with_state,
+                         init_state=init_state,
+                         state_map=state_map,
+                     )
+                 else:
+                     return self.ops_operator(
+                         r=r,
+                         k=k,
+                         v=v,
+                         w=w,
+                         u=u,
+                         with_state=with_state,
+                         init_state=init_state,
+                         state_map=state_map,
+                     )
+
+             def call_one_step():
+                 return self.ops_operator(
+                     r=r,
+                     k=k,
+                     v=v,
+                     w=w,
+                     u=u,
+                     with_state=with_state,
+                     init_state=init_state,
+                     state_map=state_map,
+                 )
+
+             if not self.ops_loop:
+                 return ops.cond(
+                     seq_len != 1 and not ops_kernel, call_parallel, call_one_step
+                 )
+             else:
+                 return call_parallel()
+
+     return RWKVKernelOperator
+
+
+ # from .ops_rwkv_kernal import RWKVKernelOperator as OPSKernelOperator
+
+
+ """
+ Three new parameters are added:
+ return_state: bool. Whether to return the final state; enable this switch if you also want to pass a custom init_state.
+
+ init_state
+ When init_state is omitted, the state is zero-initialized along the BatchSize dimension.
+ Shape: (state_kinds, num_heads, head_size, head_size), where state_kinds is a positive integer no larger than Batch_Size.
+ Precision: fp32 when r is fp16; otherwise the same dtype as r.
+
+
+ state_map
+ Shape: (Batch_Size,)
+ Precision: int64, list[int]
+ This array defines the mapping from states to each batch-dimension slice of r.
+ Value range: [0, state_kinds)
+
+ Returns:
+ output, output_state
+
+ def __call__(self, r, k, v, w, u, return_state=False, init_state=None, state_map=None):
+
+
+
+
+ """
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp
@@ -0,0 +1,44 @@
+ /* Copyright 2024 The JAX Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #include "kernels.h"
+ #include "pybind11_kernel_helpers.h"
+
+ namespace {
+ pybind11::dict WKVRegistrations() {
+   pybind11::dict dict;
+   dict["wkv_forward"] =
+       gpu_ops::EncapsulateFunction(gpu_ops::rwkv_forward_fn);
+   dict["wkv_backward"] =
+       gpu_ops::EncapsulateFunction(gpu_ops::rwkv_backward_fn);
+   dict["wkv_forward_with_state"] =
+       gpu_ops::EncapsulateFunction(gpu_ops::rwkv_forward_with_state_fn);
+   return dict;
+ }
+
+ PYBIND11_MODULE(gpu_ops, m) {
+   m.def("get_rwkv_registrations", &WKVRegistrations);
+   m.def("create_rwkv_descriptor",
+         [](int B, int T, int C, int H, bool S, gpu_ops::ElementType input_type, gpu_ops::ElementType output_type) {
+           return gpu_ops::PackDescriptor(gpu_ops::WKVDescriptor{B, T, C, H, S, input_type, output_type});
+         });
+
+   pybind11::enum_<gpu_ops::ElementType>(m, "ElementType")
+       .value("BF16", gpu_ops::ElementType::BF16)
+       .value("F16", gpu_ops::ElementType::F16)
+       .value("F32", gpu_ops::ElementType::F32);
+
+ }
+ } // namespace
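
The pybind11 module above only exposes the registration dict, the descriptor packer, and the dtype enum; the JAX-side wiring lives in jax_rwkv_kernel.py, which is not shown in this excerpt. The sketch below is an editor's illustration of how such registrations are conventionally consumed with the pre-FFI custom-call API (jax < 0.6, matching the version gate in __init__.py), not the package's actual code.

# Editor's illustration: registering the custom-call targets exported by gpu_ops.
from jax.lib import xla_client

import gpu_ops  # the compiled extension defined by gpu_ops.cpp

# Each value is a PyCapsule produced by EncapsulateFunction() in
# pybind11_kernel_helpers.h, tagged "xla._CUSTOM_CALL_TARGET".
for name, capsule in gpu_ops.get_rwkv_registrations().items():
    xla_client.register_custom_call_target(name, capsule, platform="gpu")

# The opaque blob attached to each custom call is the packed WKVDescriptor bytes.
B, T, C, H = 2, 16, 256, 4  # illustrative problem sizes
opaque = gpu_ops.create_rwkv_descriptor(
    B, T, C, H, True, gpu_ops.ElementType.F32, gpu_ops.ElementType.F32
)
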
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h
@@ -0,0 +1,64 @@
+ /* Copyright 2024 The JAX Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ // This header is not specific to our application and you'll probably want
+ // something like this for any extension you're building. This includes the
+ // infrastructure needed to serialize descriptors that are used with the
+ // "opaque" parameter of the GPU custom call. In our example we'll use this
+ // parameter to pass the size of our problem.
+
+ #ifndef _GPU_OPS_KERNEL_HELPERS_H_
+ #define _GPU_OPS_KERNEL_HELPERS_H_
+
+ #include <cstdint>
+ #include <stdexcept>
+ #include <string>
+ #include <type_traits>
+
+ #define JAX_APEX_WARP_SIZE 32
+
+ namespace gpu_ops {
+
+ // https://en.cppreference.com/w/cpp/numeric/bit_cast
+ template <class To, class From>
+ typename std::enable_if<sizeof(To) == sizeof(From) &&
+                             std::is_trivially_copyable<From>::value &&
+                             std::is_trivially_copyable<To>::value,
+                         To>::type
+ bit_cast(const From &src) noexcept {
+   static_assert(std::is_trivially_constructible<To>::value,
+                 "This implementation additionally requires destination type to "
+                 "be trivially constructible");
+
+   To dst;
+   memcpy(&dst, &src, sizeof(To));
+   return dst;
+ }
+
+ template <typename T> std::string PackDescriptorAsString(const T &descriptor) {
+   return std::string(bit_cast<const char *>(&descriptor), sizeof(T));
+ }
+
+ template <typename T>
+ const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
+   if (opaque_len != sizeof(T)) {
+     throw std::runtime_error("Invalid opaque object size");
+   }
+   return bit_cast<const T *>(opaque);
+ }
+
+ } // namespace gpu_ops
+
+ #endif
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h
@@ -0,0 +1,56 @@
+ /* Copyright 2024 The JAX Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ #ifndef _GPU_OPS_KERNELS_H_
+ #define _GPU_OPS_KERNELS_H_
+
+ #include <cuda_runtime_api.h>
+
+ #include <cstddef>
+ #include <cstdint>
+
+ #ifndef _N_
+ #define _N_ 8
+ #endif
+ #ifndef _T_
+ #define _T_ 16
+ #endif
+ namespace gpu_ops {
+
+ enum ElementType { BF16, F16, F32 };
+
+ struct WKVDescriptor {
+   int B;
+   int T;
+   int C;
+   int H;
+   bool S;
+   ElementType x_type;
+   ElementType y_type;
+ };
+
+ void rwkv_forward_fn(cudaStream_t stream, void **buffers,
+                      const char *opaque,
+                      std::size_t opaque_len);
+ void rwkv_backward_fn(cudaStream_t stream, void **buffers,
+                       const char *opaque,
+                       std::size_t opaque_len);
+
+ void rwkv_forward_with_state_fn(cudaStream_t stream, void **buffers,
+                                 const char *opaque,
+                                 std::size_t opaque_len);
+ } // namespace gpu_ops
+
+ #endif
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h
@@ -0,0 +1,41 @@
+ /* Copyright 2024 The JAX Authors.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+
+ // This header extends kernel_helpers.h with the pybind11 specific interface to
+ // serializing descriptors. It also adds a pybind11 function for wrapping our
+ // custom calls in a Python capsule. This is separate from kernel_helpers so
+ // that the CUDA code itself doesn't include pybind11. I don't think that this
+ // is strictly necessary, but they do it in jaxlib, so let's do it here too.
+
+ #ifndef _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
+ #define _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
+
+ #include <pybind11/pybind11.h>
+
+ #include "kernel_helpers.h"
+
+ namespace gpu_ops {
+
+ template <typename T> pybind11::bytes PackDescriptor(const T &descriptor) {
+   return pybind11::bytes(PackDescriptorAsString(descriptor));
+ }
+
+ template <typename T> pybind11::capsule EncapsulateFunction(T *fn) {
+   return pybind11::capsule(bit_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
+ }
+
+ } // namespace gpu_ops
+
+ #endif