rwkv-ops 0.2.2-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of rwkv-ops might be problematic.

Files changed (31)
  1. rwkv_ops/__init__.py +5 -6
  2. rwkv_ops/rwkv6_kernel/__init__.py +0 -6
  3. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  4. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  5. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  6. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  7. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  8. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  9. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  10. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  11. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  12. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  13. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +21 -23
  14. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +14 -10
  15. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  16. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  17. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +4 -4
  18. rwkv_ops/rwkv7_kernel/__init__.py +77 -29
  19. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  20. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +279 -0
  21. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +237 -0
  22. rwkv_ops/rwkv7_kernel/jax_op.py +6 -5
  23. rwkv_ops/rwkv7_kernel/native_keras_op.py +5 -6
  24. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +123 -0
  25. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +165 -0
  26. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +35 -0
  27. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/METADATA +28 -27
  28. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/RECORD +30 -13
  29. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info}/WHEEL +1 -2
  30. rwkv_ops-0.2.2.dist-info/top_level.txt +0 -1
  31. {rwkv_ops-0.2.2.dist-info → rwkv_ops-0.3.0.dist-info/licenses}/LICENSE.txt +0 -0
rwkv_ops/__init__.py CHANGED
@@ -1,17 +1,16 @@
- __version__ = "0.2.2"
+ __version__ = "0.3.0"
  import os

- KERNEL_TYPE = os.environ.get("KERNEL_TYPE", "triton")
+ KERNEL_TYPE = os.environ.get("KERNEL_TYPE", "cuda").lower()
  KERAS_BACKEND = os.environ.get("KERAS_BACKEND")
  BACKEND = os.environ.get("KERNEL_BACKEND")


  if KERAS_BACKEND is not None:
-     BACKEND = KERAS_BACKEND
+     BACKEND = KERAS_BACKEND.lower()
  elif BACKEND is not None:
-     os.environ["KERAS_BACKEND"] = BACKEND
+     os.environ["KERAS_BACKEND"] = BACKEND.lower()
  else:
-     import torch
      import keras

      BACKEND = "torch"
@@ -22,7 +21,7 @@ assert BACKEND in ["torch", "jax", "numpy", "tensorflow"]
  from .rwkv7_kernel import get_generalized_delta_rule
  from .rwkv6_kernel import get_rwkv6_kernel

- generalized_delta_rule, RWKV7_USE_KERNEL = get_generalized_delta_rule(
+ generalized_delta_rule, RWKV7_USE_TRITON_KERNEL = get_generalized_delta_rule(
      KERNEL_TYPE=KERNEL_TYPE
  )
  rwkv7_op = generalized_delta_rule
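
The practical effect of this hunk: the default kernel flavor changes from "triton" to "cuda", both the kernel type and the backend name are now lower-cased, and the unconditional torch import is gone from the fallback branch. A minimal configuration sketch against the new behavior (rwkv7_op, generalized_delta_rule, and RWKV7_USE_TRITON_KERNEL are the names exported above; the chosen values are illustrative):

import os

# Pick backend and kernel flavor before the first import; 0.3.0
# lower-cases both, so "JAX"/"jax" and "Triton"/"triton" are equivalent.
os.environ["KERNEL_BACKEND"] = "jax"   # also mirrored into KERAS_BACKEND
os.environ["KERNEL_TYPE"] = "triton"   # the new default is "cuda"

import rwkv_ops

assert rwkv_ops.__version__ == "0.3.0"
op = rwkv_ops.rwkv7_op                          # alias of generalized_delta_rule
use_triton = rwkv_ops.RWKV7_USE_TRITON_KERNEL   # flag renamed in 0.3.0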
rwkv_ops/rwkv6_kernel/__init__.py CHANGED
@@ -17,12 +17,6 @@ def get_rwkv6_kernel(KERNEL_TYPE="native"):

              ops_kernel = False
          else:
-             print(
-                 "The RWKV6 CUDA kernel cannot be used with JAX >= 0.6. If you need it, downgrade JAX; 0.4.34 is recommended."
-             )
-             print(
-                 "RWKV6的CUDA kernel在JAX >= 0.6版本无法使用，如果需要使用RWKV6版本的CUDA KERNEL，请降级版本，建议降级到0.4.34"
-             )
              CudaOperator = None
      elif keras.config.backend() == "torch":
          from .torch_rwkv_kernel import RWKVKernelOperator as CudaOperator
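
With the bilingual warning prints removed, a JAX installation where the CUDA extension cannot load now falls back silently (CudaOperator is simply set to None). A hedged sketch of defensive usage, assuming get_rwkv6_kernel surfaces the missing operator as None rather than raising:

import rwkv_ops

# Request the CUDA kernel; on JAX >= 0.6 the CUDA path is unavailable
# (the loader sets CudaOperator = None), so keep a native fallback ready.
kernel = rwkv_ops.get_rwkv6_kernel(KERNEL_TYPE="cuda")
if kernel is None:
    kernel = rwkv_ops.get_rwkv6_kernel(KERNEL_TYPE="native")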
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp ADDED
@@ -0,0 +1,44 @@
+ /* Copyright 2024 The JAX Authors.
+ 
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ 
+     http://www.apache.org/licenses/LICENSE-2.0
+ 
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+ 
+ #include "kernels.h"
+ #include "pybind11_kernel_helpers.h"
+ 
+ namespace {
+ pybind11::dict WKVRegistrations() {
+   pybind11::dict dict;
+   dict["wkv_forward"] =
+       gpu_ops::EncapsulateFunction(gpu_ops::rwkv_forward_fn);
+   dict["wkv_backward"] =
+       gpu_ops::EncapsulateFunction(gpu_ops::rwkv_backward_fn);
+   dict["wkv_forward_with_state"] =
+       gpu_ops::EncapsulateFunction(gpu_ops::rwkv_forward_with_state_fn);
+   return dict;
+ }
+ 
+ PYBIND11_MODULE(gpu_ops, m) {
+   m.def("get_rwkv_registrations", &WKVRegistrations);
+   m.def("create_rwkv_descriptor",
+         [](int B, int T, int C, int H, bool S,
+            gpu_ops::ElementType input_type, gpu_ops::ElementType output_type) {
+           return gpu_ops::PackDescriptor(
+               gpu_ops::WKVDescriptor{B, T, C, H, S, input_type, output_type});
+         });
+ 
+   pybind11::enum_<gpu_ops::ElementType>(m, "ElementType")
+       .value("BF16", gpu_ops::ElementType::BF16)
+       .value("F16", gpu_ops::ElementType::F16)
+       .value("F32", gpu_ops::ElementType::F32);
+ }
+ } // namespace
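
On the Python side, the compiled gpu_ops extension hands back a dict of capsules that must be registered with XLA before the first trace. A sketch of the conventional registration loop for this pre-FFI custom-call style (the gpu_ops names come from the bindings above; xla_client.register_custom_call_target is the classic entry point this design targets, available in the older jaxlib versions this kernel supports):

from jax.lib import xla_client

import gpu_ops  # the compiled pybind11 module from this directory

# Each dict value is a PyCapsule named "xla._CUSTOM_CALL_TARGET"
# (see pybind11_kernel_helpers.h); XLA looks targets up by dict key.
for _name, _capsule in gpu_ops.get_rwkv_registrations().items():
    xla_client.register_custom_call_target(_name, _capsule, platform="gpu")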
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h ADDED
@@ -0,0 +1,64 @@
+ /* Copyright 2024 The JAX Authors.
+ 
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ 
+     http://www.apache.org/licenses/LICENSE-2.0
+ 
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+ 
+ // This header is not specific to our application and you'll probably want
+ // something like this for any extension you're building. This includes the
+ // infrastructure needed to serialize descriptors that are used with the
+ // "opaque" parameter of the GPU custom call. In our example we'll use this
+ // parameter to pass the size of our problem.
+ 
+ #ifndef _GPU_OPS_KERNEL_HELPERS_H_
+ #define _GPU_OPS_KERNEL_HELPERS_H_
+ 
+ #include <cstdint>
+ #include <cstring>  // std::memcpy
+ #include <stdexcept>
+ #include <string>
+ #include <type_traits>
+ 
+ #define JAX_APEX_WARP_SIZE 32
+ 
+ namespace gpu_ops {
+ 
+ // https://en.cppreference.com/w/cpp/numeric/bit_cast
+ template <class To, class From>
+ typename std::enable_if<sizeof(To) == sizeof(From) &&
+                             std::is_trivially_copyable<From>::value &&
+                             std::is_trivially_copyable<To>::value,
+                         To>::type
+ bit_cast(const From &src) noexcept {
+   static_assert(std::is_trivially_constructible<To>::value,
+                 "This implementation additionally requires destination type to "
+                 "be trivially constructible");
+ 
+   To dst;
+   memcpy(&dst, &src, sizeof(To));
+   return dst;
+ }
+ 
+ template <typename T> std::string PackDescriptorAsString(const T &descriptor) {
+   return std::string(bit_cast<const char *>(&descriptor), sizeof(T));
+ }
+ 
+ template <typename T>
+ const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
+   if (opaque_len != sizeof(T)) {
+     throw std::runtime_error("Invalid opaque object size");
+   }
+   return bit_cast<const T *>(opaque);
+ }
+ 
+ } // namespace gpu_ops
+ 
+ #endif
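
PackDescriptorAsString is a raw memcpy of the struct, so the opaque bytes use the compiler's native layout, padding included. A sanity-check sketch in Python, assuming a typical x86-64 ABI for the WKVDescriptor declared in kernels.h (four ints, a padded bool, two int-sized enums); the field values are arbitrary:

import struct

# "@" = native byte order/alignment, mirroring what memcpy of the C++
# struct produces; "?" is the bool, padded so the enums stay 4-aligned.
fmt = "@iiii?ii"
opaque = struct.pack(fmt, 1, 16, 512, 8, True, 2, 2)  # B, T, C, H, S, F32, F32
assert struct.calcsize(fmt) == 28  # matches sizeof(WKVDescriptor) on such ABIs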
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h ADDED
@@ -0,0 +1,56 @@
+ /* Copyright 2024 The JAX Authors.
+ 
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ 
+     http://www.apache.org/licenses/LICENSE-2.0
+ 
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+ 
+ #ifndef _GPU_OPS_KERNELS_H_
+ #define _GPU_OPS_KERNELS_H_
+ 
+ #include <cuda_runtime_api.h>
+ 
+ #include <cstddef>
+ #include <cstdint>
+ 
+ #ifndef _N_
+ #define _N_ 8
+ #endif
+ #ifndef _T_
+ #define _T_ 16
+ #endif
+ 
+ namespace gpu_ops {
+ 
+ enum ElementType { BF16, F16, F32 };
+ 
+ struct WKVDescriptor {
+   int B;
+   int T;
+   int C;
+   int H;
+   bool S;
+   ElementType x_type;
+   ElementType y_type;
+ };
+ 
+ void rwkv_forward_fn(cudaStream_t stream, void **buffers,
+                      const char *opaque, std::size_t opaque_len);
+ void rwkv_backward_fn(cudaStream_t stream, void **buffers,
+                       const char *opaque, std::size_t opaque_len);
+ 
+ void rwkv_forward_with_state_fn(cudaStream_t stream, void **buffers,
+                                 const char *opaque, std::size_t opaque_len);
+ 
+ } // namespace gpu_ops
+ 
+ #endif
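
Combined with the binding in gpu_ops.cpp, each descriptor field maps one-to-one onto a create_rwkv_descriptor argument. A small sketch (the reading of B/T/C/H/S as batch, sequence length, channels, heads, and a with-state flag follows the usual RWKV convention and is inferred, not stated in the diff; values are illustrative):

import gpu_ops

B, T, C, H = 4, 1024, 2048, 32  # batch, time, channels, heads
assert C % H == 0               # per-head size must suit the compiled _N_/_T_ tiles

opaque = gpu_ops.create_rwkv_descriptor(
    B, T, C, H,
    True,                        # S: route to the forward-with-state kernel
    gpu_ops.ElementType.BF16,    # input element type (x_type)
    gpu_ops.ElementType.F32,     # output element type (y_type)
)
# `opaque` is the serialized WKVDescriptor handed to the XLA custom call.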
rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h ADDED
@@ -0,0 +1,41 @@
+ /* Copyright 2024 The JAX Authors.
+ 
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ 
+     http://www.apache.org/licenses/LICENSE-2.0
+ 
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+ 
+ // This header extends kernel_helpers.h with the pybind11 specific interface to
+ // serializing descriptors. It also adds a pybind11 function for wrapping our
+ // custom calls in a Python capsule. This is separate from kernel_helpers so
+ // that the CUDA code itself doesn't include pybind11. I don't think that this
+ // is strictly necessary, but they do it in jaxlib, so let's do it here too.
+ 
+ #ifndef _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
+ #define _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
+ 
+ #include <pybind11/pybind11.h>
+ 
+ #include "kernel_helpers.h"
+ 
+ namespace gpu_ops {
+ 
+ template <typename T> pybind11::bytes PackDescriptor(const T &descriptor) {
+   return pybind11::bytes(PackDescriptorAsString(descriptor));
+ }
+ 
+ template <typename T> pybind11::capsule EncapsulateFunction(T *fn) {
+   return pybind11::capsule(bit_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
+ }
+ 
+ } // namespace gpu_ops
+ 
+ #endif
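
Since XLA only accepts capsules carrying exactly the name EncapsulateFunction bakes in, a quick sanity check from Python can catch a miscompiled module early. A hedged sketch using the CPython capsule API via ctypes:

import ctypes

import gpu_ops

PyCapsule_GetName = ctypes.pythonapi.PyCapsule_GetName
PyCapsule_GetName.restype = ctypes.c_char_p
PyCapsule_GetName.argtypes = [ctypes.py_object]

# Every registration must be a capsule named "xla._CUSTOM_CALL_TARGET",
# the tag XLA checks before accepting a custom-call target.
for name, capsule in gpu_ops.get_rwkv_registrations().items():
    assert PyCapsule_GetName(capsule) == b"xla._CUSTOM_CALL_TARGET", name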