rwkv-ops 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# copy right from https://github.com/infiy-quine/RWKV6_Keras_Operator
|
|
2
|
+
import os
|
|
3
|
+
import keras
|
|
4
|
+
from keras import ops
|
|
5
|
+
from distutils.util import strtobool
|
|
6
|
+
from packaging import version
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def get_rwkv6_kernel(KERNEL_TYPE="native"):
    """Return an ``RWKVKernelOperator`` class bound to the available kernels.

    Args:
        KERNEL_TYPE: ``"cuda"`` requests the compiled CUDA operator for the
            current Keras backend. A CUDA operator is only selected for the
            ``jax`` backend with ``jax < 0.6.0`` (``jax_rwkv_kernel``) or for
            the ``torch`` backend (``torch_rwkv_kernel``). In every other case,
            including any other ``KERNEL_TYPE`` value, only the operator from
            ``ops_rwkv_kernel`` is used.

    Returns:
        The ``RWKVKernelOperator`` class (not an instance). Its instances
        dispatch each call to the CUDA operator or to the ops operator.
    """
    CudaOperator = None
    # True when no CUDA operator was selected, i.e. every call goes through the
    # ops-based implementation.
    ops_kernel = True
    if KERNEL_TYPE == "cuda":
        backend = keras.config.backend()
        if backend == "jax":
            import jax

            if version.parse(jax.__version__) < version.parse("0.6.0"):
                from .jax_rwkv_kernel import RWKVKernelOperator as CudaOperator

                ops_kernel = False
        elif backend == "torch":
            from .torch_rwkv_kernel import RWKVKernelOperator as CudaOperator

            ops_kernel = False
    from .ops_rwkv_kernel import RWKVKernelOperator as OpsOperator

    class RWKVKernelOperator:
        """Dispatch RWKV6 calls to the CUDA operator (if any) or the ops operator."""

        def __init__(self, head_size, max_sequence_length, ops_loop=False):
            """Construct the underlying operator(s).

            Args:
                head_size: Forwarded to the operator constructors.
                max_sequence_length: Forwarded to the operator constructors.
                ops_loop: If True, every call takes the "parallel" path
                    (the CUDA operator when available, otherwise the ops
                    operator) regardless of the sequence length.
            """
            # Attribute name (typo included) kept for backward compatibility.
            self.enbale_cuda = CudaOperator is not None
            if self.enbale_cuda:
                self.cuda_operator = CudaOperator(head_size, max_sequence_length)
            self.ops_operator = OpsOperator(head_size, max_sequence_length)
            self.ops_loop = ops_loop

        def __call__(
            self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
        ):
            """Run the WKV operator, forwarding all arguments unchanged.

            Returns whatever the selected operator returns.
            """
            seq_len = r.shape[1]
            kwargs = dict(
                r=r,
                k=k,
                v=v,
                w=w,
                u=u,
                with_state=with_state,
                init_state=init_state,
                state_map=state_map,
            )
            # Both `seq_len` (assumed to be a static Python int or None, as for
            # Keras/torch tensor shapes) and `ops_kernel` are Python values, so
            # dispatch with a plain Python branch. keras.ops.cond on a Python
            # bool traces both operators under JAX (jax.lax.cond) and requires
            # their outputs to have identical structure.
            use_parallel = self.ops_loop or (seq_len != 1 and not ops_kernel)
            if use_parallel and self.enbale_cuda:
                return self.cuda_operator(**kwargs)
            return self.ops_operator(**kwargs)

    return RWKVKernelOperator
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
# from .ops_rwkv_kernal import RWKVKernelOperator as OPSKernelOperator
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
"""
|
|
97
|
+
新增三个参数
|
|
98
|
+
return_state 布尔类型 是否返回最终的state,如果想自定义init_state也需要启用这个开关
|
|
99
|
+
|
|
100
|
+
init_state
|
|
101
|
+
当init_state缺省时,则使用全零初始化Batch_Size维度上的状态。
|
|
102
|
+
形状: (state_kinds,num_heads,head_size, head_size), 其中state_kinds为小于等于Batch_Size的正整数
|
|
103
|
+
精度: 在r为fp16时 init_state为fp32 其余时候类型与r相同
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
state_map
|
|
107
|
+
形状: (Batch_Size,)
|
|
108
|
+
精度: int64, list[int]
|
|
109
|
+
这个数组定义了state到r上每个Batch维度切片间的映射关系
|
|
110
|
+
取值范围: [0, state_kinds)
|
|
111
|
+
|
|
112
|
+
返回:
|
|
113
|
+
output, output_state
|
|
114
|
+
|
|
115
|
+
def __call__(self,r, k, v, w, u, return_state=False, init_state=None, state_map=None):
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
"""
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
/* Copyright 2024 The JAX Authors.
|
|
2
|
+
|
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
you may not use this file except in compliance with the License.
|
|
5
|
+
You may obtain a copy of the License at
|
|
6
|
+
|
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
See the License for the specific language governing permissions and
|
|
13
|
+
limitations under the License.
|
|
14
|
+
==============================================================================*/
|
|
15
|
+
|
|
16
|
+
#include "kernels.h"
|
|
17
|
+
#include "pybind11_kernel_helpers.h"
|
|
18
|
+
|
|
19
|
+
namespace {
|
|
20
|
+
pybind11::dict WKVRegistrations() {
|
|
21
|
+
pybind11::dict dict;
|
|
22
|
+
dict["wkv_forward"] =
|
|
23
|
+
gpu_ops::EncapsulateFunction(gpu_ops::rwkv_forward_fn);
|
|
24
|
+
dict["wkv_backward"] =
|
|
25
|
+
gpu_ops::EncapsulateFunction(gpu_ops::rwkv_backward_fn);
|
|
26
|
+
dict["wkv_forward_with_state"] =
|
|
27
|
+
gpu_ops::EncapsulateFunction(gpu_ops::rwkv_forward_with_state_fn);
|
|
28
|
+
return dict;
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
PYBIND11_MODULE(gpu_ops, m) {
|
|
32
|
+
m.def("get_rwkv_registrations", &WKVRegistrations);
|
|
33
|
+
m.def("create_rwkv_descriptor",
|
|
34
|
+
[](int B, int T,int C, int H,bool S, gpu_ops::ElementType input_type,gpu_ops::ElementType output_type) {
|
|
35
|
+
return gpu_ops::PackDescriptor(gpu_ops::WKVDescriptor{B, T, C, H, S, input_type, output_type});
|
|
36
|
+
});
|
|
37
|
+
|
|
38
|
+
pybind11::enum_<gpu_ops::ElementType>(m, "ElementType")
|
|
39
|
+
.value("BF16", gpu_ops::ElementType::BF16)
|
|
40
|
+
.value("F16", gpu_ops::ElementType::F16)
|
|
41
|
+
.value("F32", gpu_ops::ElementType::F32);
|
|
42
|
+
|
|
43
|
+
}
|
|
44
|
+
} // namespace
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
/* Copyright 2024 The JAX Authors.
|
|
2
|
+
|
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
you may not use this file except in compliance with the License.
|
|
5
|
+
You may obtain a copy of the License at
|
|
6
|
+
|
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
See the License for the specific language governing permissions and
|
|
13
|
+
limitations under the License.
|
|
14
|
+
==============================================================================*/
|
|
15
|
+
|
|
16
|
+
// This header is not specific to our application and you'll probably want
|
|
17
|
+
// something like this for any extension you're building. This includes the
|
|
18
|
+
// infrastructure needed to serialize descriptors that are used with the
|
|
19
|
+
// "opaque" parameter of the GPU custom call. In our example we'll use this
|
|
20
|
+
// parameter to pass the size of our problem.
|
|
21
|
+
|
|
22
|
+
#ifndef _GPU_OPS_KERNEL_HELPERS_H_
|
|
23
|
+
#define _GPU_OPS_KERNEL_HELPERS_H_
|
|
24
|
+
|
|
25
|
+
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>
|
|
29
|
+
|
|
30
|
+
#define JAX_APEX_WARP_SIZE 32
|
|
31
|
+
|
|
32
|
+
namespace gpu_ops {

// Reinterpret the object representation of `src` as a `To`; a pre-C++20
// stand-in for std::bit_cast (https://en.cppreference.com/w/cpp/numeric/bit_cast).
// Only participates in overload resolution when both types have the same size
// and are trivially copyable.
template <class To, class From>
typename std::enable_if<sizeof(To) == sizeof(From) &&
                            std::is_trivially_copyable<From>::value &&
                            std::is_trivially_copyable<To>::value,
                        To>::type
bit_cast(const From &src) noexcept {
  static_assert(std::is_trivially_constructible<To>::value,
                "This implementation additionally requires destination type to "
                "be trivially constructible");

  To dst;
  // std::memcpy (declared in <cstring>) is the well-defined way to copy the
  // bytes of one trivially copyable object into another.
  std::memcpy(&dst, &src, sizeof(To));
  return dst;
}

// Serialize `descriptor` as exactly sizeof(T) raw bytes, e.g. for use as an
// opaque payload string.
template <typename T> std::string PackDescriptorAsString(const T &descriptor) {
  return std::string(bit_cast<const char *>(&descriptor), sizeof(T));
}

// View an opaque byte buffer produced by PackDescriptorAsString<T> as a T.
// Throws std::runtime_error when opaque_len != sizeof(T).
// NOTE(review): the returned pointer aliases `opaque` directly (no copy), so
// this assumes the buffer is suitably aligned for T and outlives the pointer —
// confirm for the callers' buffers.
template <typename T>
const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
  if (opaque_len != sizeof(T)) {
    throw std::runtime_error("Invalid opaque object size");
  }
  return bit_cast<const T *>(opaque);
}

} // namespace gpu_ops
|
|
63
|
+
|
|
64
|
+
#endif
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
/* Copyright 2024 The JAX Authors.
|
|
2
|
+
|
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
you may not use this file except in compliance with the License.
|
|
5
|
+
You may obtain a copy of the License at
|
|
6
|
+
|
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
See the License for the specific language governing permissions and
|
|
13
|
+
limitations under the License.
|
|
14
|
+
==============================================================================*/
|
|
15
|
+
|
|
16
|
+
#ifndef _GPU_OPS_KERNELS_H_
|
|
17
|
+
#define _GPU_OPS_KERNELS_H_
|
|
18
|
+
|
|
19
|
+
#include <cuda_runtime_api.h>
|
|
20
|
+
|
|
21
|
+
#include <cstddef>
|
|
22
|
+
#include <cstdint>
|
|
23
|
+
|
|
24
|
+
#ifndef _N_
|
|
25
|
+
#define _N_ 8
|
|
26
|
+
#endif
|
|
27
|
+
#ifndef _T_
|
|
28
|
+
#define _T_ 16
|
|
29
|
+
#endif
|
|
30
|
+
namespace gpu_ops {
|
|
31
|
+
|
|
32
|
+
enum ElementType { BF16, F16, F32 };
|
|
33
|
+
|
|
34
|
+
struct WKVDescriptor {
|
|
35
|
+
int B;
|
|
36
|
+
int T;
|
|
37
|
+
int C;
|
|
38
|
+
int H;
|
|
39
|
+
bool S;
|
|
40
|
+
ElementType x_type;
|
|
41
|
+
ElementType y_type;
|
|
42
|
+
};
|
|
43
|
+
|
|
44
|
+
void rwkv_forward_fn(cudaStream_t stream, void **buffers,
|
|
45
|
+
const char *opaque,
|
|
46
|
+
std::size_t opaque_len);
|
|
47
|
+
void rwkv_backward_fn(cudaStream_t stream, void **buffers,
|
|
48
|
+
const char *opaque,
|
|
49
|
+
std::size_t opaque_len);
|
|
50
|
+
|
|
51
|
+
void rwkv_forward_with_state_fn(cudaStream_t stream, void **buffers,
|
|
52
|
+
const char *opaque,
|
|
53
|
+
std::size_t opaque_len);
|
|
54
|
+
} // namespace gpu_ops
|
|
55
|
+
|
|
56
|
+
#endif
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
/* Copyright 2024 The JAX Authors.
|
|
2
|
+
|
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
you may not use this file except in compliance with the License.
|
|
5
|
+
You may obtain a copy of the License at
|
|
6
|
+
|
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
See the License for the specific language governing permissions and
|
|
13
|
+
limitations under the License.
|
|
14
|
+
==============================================================================*/
|
|
15
|
+
|
|
16
|
+
// This header extends kernel_helpers.h with the pybind11 specific interface to
|
|
17
|
+
// serializing descriptors. It also adds a pybind11 function for wrapping our
|
|
18
|
+
// custom calls in a Python capsule. This is separate from kernel_helpers so
|
|
19
|
+
// that the CUDA code itself doesn't include pybind11. I don't think that this
|
|
20
|
+
// is strictly necessary, but they do it in jaxlib, so let's do it here too.
|
|
21
|
+
|
|
22
|
+
#ifndef _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
|
|
23
|
+
#define _GPU_OPS_PYBIND11_KERNEL_HELPERS_H_
|
|
24
|
+
|
|
25
|
+
#include <pybind11/pybind11.h>
|
|
26
|
+
|
|
27
|
+
#include "kernel_helpers.h"
|
|
28
|
+
|
|
29
|
+
namespace gpu_ops {
|
|
30
|
+
|
|
31
|
+
template <typename T> pybind11::bytes PackDescriptor(const T &descriptor) {
|
|
32
|
+
return pybind11::bytes(PackDescriptorAsString(descriptor));
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
template <typename T> pybind11::capsule EncapsulateFunction(T *fn) {
|
|
36
|
+
return pybind11::capsule(bit_cast<void *>(fn), "xla._CUSTOM_CALL_TARGET");
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
} // namespace gpu_ops
|
|
40
|
+
|
|
41
|
+
#endif
|