rwkv-ops 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rwkv_ops/__init__.py +45 -0
- rwkv_ops/mhc_kernel/__init__.py +50 -0
- rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
- rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
- rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
- rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
- rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
- rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
- rwkv_ops/rwkv6_kernel/__init__.py +120 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
- rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
- rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
- rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
- rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
- rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
- rwkv_ops/rwkv7_kernel/__init__.py +113 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
- rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
- rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
- rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.6.1.dist-info/METADATA +495 -0
- rwkv_ops-0.6.1.dist-info/RECORD +89 -0
- rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
- rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import os
|
|
3
|
+
from functools import lru_cache
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
import triton
|
|
7
|
+
from packaging import version
|
|
8
|
+
import torch
|
|
9
|
+
from enum import Enum
|
|
10
|
+
import contextlib
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@lru_cache(maxsize=None)
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
    """Return the multiprocessor count Triton reports for one device.

    Args:
        tensor_idx: Device index handed to Triton's ``get_device_properties``.

    Returns:
        The ``"multiprocessor_count"`` entry of the device-properties mapping.
        Results are memoized per index.
    """
    driver_utils = triton.runtime.driver.active.utils
    properties = driver_utils.get_device_properties(tensor_idx)
    return properties["multiprocessor_count"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@lru_cache(maxsize=None)
def get_available_device() -> str:
    """Return the name of the active Triton backend (e.g. ``"cuda"``, ``"hip"``).

    If the Triton driver cannot report a target, a warning is emitted and
    ``"cpu"`` is returned. Only ``Exception`` subclasses trigger this fallback.
    ``KeyboardInterrupt``/``SystemExit`` propagate instead of being converted
    into a (permanently cached) CPU fallback.

    Returns:
        ``get_current_target().backend`` of the active Triton driver, or ``"cpu"``.
    """
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except Exception:
        import warnings

        warnings.warn(
            ("Triton is not supported on current platform, roll back to CPU."),
            stacklevel=1,
        )
        return "cpu"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@lru_cache(maxsize=None)
def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
    """Translate the Triton backend name into a GPU vendor name.

    ``"cuda"`` maps to ``"nvidia"``, ``"hip"`` to ``"amd"`` and ``"xpu"`` to
    ``"intel"``. Any other backend string (including the ``"cpu"`` fallback of
    ``get_available_device``) is returned unchanged. The result is memoized.
    """
    backend = get_available_device()
    vendor_for_backend = {"cuda": "nvidia", "hip": "amd", "xpu": "intel"}
    return vendor_for_backend.get(backend, backend)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
# `device` is the Triton backend name with 'hip' rewritten to 'cuda' (torch's name for it).
device = get_available_device() if get_available_device() != "hip" else "cuda"

# Vendor name ("nvidia" / "amd" / "intel", or the raw backend string otherwise).
device_platform = _check_platform()

# Convenience vendor flags used by the rest of the module.
is_intel = device_platform == "intel"
is_nvidia = device_platform == "nvidia"
is_amd = device_platform == "amd"

# CUDA graphs are opt-in (FLA_USE_CUDA_GRAPH=1) and only enabled on Nvidia.
use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@lru_cache(maxsize=None)
def check_pytorch_version(version_s: str = "2.4") -> bool:
    """Report whether the installed torch is at least ``version_s``.

    Both versions are parsed with ``packaging.version.parse`` before comparing.
    The result is memoized per requested version string.
    """
    installed = version.parse(torch.__version__)
    required = version.parse(version_s)
    return installed >= required
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# Intel Arc A-series detection; torch.xpu is only queried when running on Intel.
is_intel_a770 = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
# NOTE(review): recomputes exactly the value already assigned to `device` earlier in this module.
device = get_available_device() if get_available_device() != "hip" else "cuda"
# torch submodule for the device type (e.g. torch.cuda, torch.xpu). This is resolved
# before the cpu -> cuda rebinding below, so on CPU it is getattr(torch, "cpu").
device_torch_lib = getattr(torch, device)
if check_pytorch_version("2.4"):
    # torch.amp.custom_fwd/custom_bwd take a device_type; a "cpu" device is passed as "cuda".
    # NOTE: this rebinds the module-level `device` as well.
    device = "cuda" if device == "cpu" else device
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)

    def custom_device_ctx(index: int):
        """Return ``device_torch_lib.device(index)`` as the device-selection context."""
        return device_torch_lib.device(index)
else:
    assert device == "cuda", (
        "Only cuda device is supported for PyTorch version < 2.4.0."
    )
    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
    autocast_custom_bwd = device_torch_lib.amp.custom_bwd

    def custom_device_ctx(index: int):
        """Return ``torch.cuda.device(index)`` as the device-selection context."""
        return torch.cuda.device(index)


# Nvidia Ampere (compute capability major >= 8) or newer; AMD and Intel have not been checked yet.
is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def get_all_max_shared_memory():
    """Return Triton's ``max_shared_mem`` for every visible device.

    This behaves like ``get_all_max_shared_mem`` elsewhere in this module. If the
    Triton driver cannot be queried (e.g. no GPU backend), it warns and returns
    ``[-1]`` instead of raising, so importing the module does not crash there.

    Returns:
        One value per index in ``range(device_torch_lib.device_count())``, or
        ``[-1]`` when the query raises an ``Exception``.
    """
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
            for i in range(device_torch_lib.device_count())
        ]
    except Exception:
        # _cpu_device_warning is defined later in the module than the import-time
        # call below, so the same warning is emitted inline here.
        import warnings

        warnings.warn(
            ("Triton is not supported on current platform, roll back to CPU."),
            stacklevel=1,
        )
        return [-1]


# Evaluated once at import time; indexed by is_triton_shared_mem_enough.
device_shared_mem_list = get_all_max_shared_memory()
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@lru_cache(maxsize=None)
def is_triton_shared_mem_enough(
    max_shared_mem: int = 102400, tensor_idx: int = 0
) -> bool:
    """Check whether device ``tensor_idx`` has at least ``max_shared_mem`` shared memory.

    The value is taken from the import-time ``device_shared_mem_list``. An index
    with no entry (e.g. no devices were enumerated) yields ``False`` instead of
    raising ``IndexError``. This matches the best-effort behaviour of
    ``check_shared_mem``.

    Args:
        max_shared_mem: Required amount, compared with the Triton-reported value.
        tensor_idx: Index into ``device_shared_mem_list``.

    Returns:
        ``True`` if the recorded value is >= ``max_shared_mem``, else ``False``.
    """
    try:
        max_shared_memory = device_shared_mem_list[tensor_idx]
    except IndexError:
        return False
    return max_shared_memory >= max_shared_mem


# Import-time flag: does device 0 meet the default threshold?
device_capacity = is_triton_shared_mem_enough()
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _cpu_device_warning():
|
|
113
|
+
import warnings
|
|
114
|
+
|
|
115
|
+
warnings.warn(
|
|
116
|
+
("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def get_all_max_shared_mem():
    """Return Triton's ``max_shared_mem`` for every visible device, best-effort.

    Returns:
        One value per index in ``range(device_torch_lib.device_count())``. If the
        query raises an ``Exception``, the CPU-fallback warning is emitted and
        ``[-1]`` is returned. ``KeyboardInterrupt``/``SystemExit`` are not swallowed.
    """
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)[
                "max_shared_mem"
            ]
            for i in range(device_torch_lib.device_count())
        ]
    except Exception:
        _cpu_device_warning()
        return [-1]
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class Backend(Enum):
    """Shared-memory thresholds for a few known GPU generations.

    The values are compared with Triton's reported ``max_shared_mem`` by
    ``check_shared_mem``; ``DEFAULT`` is used for any unrecognised name.
    """

    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        """Return the threshold for ``arch`` (case-insensitive), or ``DEFAULT``'s value."""
        member = cls.__members__.get(arch.upper(), cls.DEFAULT)
        return member.value
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@lru_cache(maxsize=None)
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    """Report whether device ``tensor_idx`` has the shared memory ``arch`` requires.

    The requirement comes from ``Backend.get_shared_memory(arch)``; unknown
    names fall back to ``Backend.DEFAULT``. Any ``Exception`` raised while
    looking things up is treated as "not enough" and yields ``False``.
    Results are memoized per ``(arch, tensor_idx)``.
    """
    try:
        available = get_all_max_shared_mem()[tensor_idx]
        return available >= Backend.get_shared_memory(arch)
    except Exception:
        return False
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def tensor_cache(fn):
    """
    Memoize only the most recent call of ``fn``, keyed on argument identity.

    A call is a cache hit when it has the same number of positional and keyword
    arguments as the previous call, and every argument is the very same object
    (``is``) as before. On a hit the previous result is returned without calling
    ``fn``; otherwise ``fn`` runs and its arguments and result replace the
    cached entry. If ``fn`` raises, the cached entry is left untouched.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to wrap, typically one taking and returning tensors.

    Returns:
        Callable[..., torch.Tensor]:
            ``fn`` wrapped with a single-entry identity cache.
    """
    memo = {"args": None, "kwargs": None, "result": None}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        prev_args, prev_kwargs = memo["args"], memo["kwargs"]
        hit = (
            prev_args is not None
            and prev_kwargs is not None
            and len(args) == len(prev_args)
            and len(kwargs) == len(prev_kwargs)
            and all(cur is old for cur, old in zip(args, prev_args))
            and all(
                name in prev_kwargs and val is prev_kwargs[name]
                for name, val in kwargs.items()
            )
        )
        if hit:
            return memo["result"]

        out = fn(*args, **kwargs)
        memo["args"], memo["kwargs"], memo["result"] = args, kwargs, out
        return out

    return wrapper
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Convert cumulative sequence offsets into per-sequence lengths.

    For example ``cu_seqlens = [0, 3, 7]`` gives ``[3, 4]``. ``tensor_cache``
    reuses the result when called again with the same ``cu_seqlens`` object.
    """
    starts = cu_seqlens[:-1]
    ends = cu_seqlens[1:]
    return ends - starts
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
@tensor_cache
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    """Enumerate the ``chunk_size`` chunks of every sequence described by ``cu_seqlens``.

    Each sequence contributes ``ceil(len / chunk_size)`` rows of the form
    ``(sequence index, chunk index within that sequence)``. Sequence indices
    are assigned by position in ``cu_seqlens``, so zero-length sequences (which
    have no chunks) still advance the numbering. The previous implementation
    counted only sequences that started a chunk, which shifted the numbering of
    every later sequence. An empty batch yields a ``(0, 2)`` tensor instead of
    raising from ``torch.cat([])``.

    Args:
        cu_seqlens: Cumulative sequence offsets, e.g. ``[0, 3, 7]``.
        chunk_size: Number of tokens per chunk.

    Returns:
        A ``(num_chunks, 2)`` tensor cast to ``cu_seqlens``'s dtype and device.
    """
    n_chunks = triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()

    # Chunk position inside its own sequence: 0..n-1 for each sequence.
    if n_chunks:
        chunk_idx = torch.cat([torch.arange(n) for n in n_chunks])
    else:
        chunk_idx = torch.zeros(0, dtype=torch.long)

    # Owning sequence of each chunk, counted over all sequences (including empty ones).
    seq_idx = torch.repeat_interleave(
        torch.arange(len(n_chunks)), torch.tensor(n_chunks, dtype=torch.long)
    )
    return torch.stack([seq_idx, chunk_idx], 1).to(cu_seqlens)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def input_guard(fn):
    """
    Decorator: pass tensor arguments as contiguous and run ``fn`` on the first tensor's device.

    The first ``torch.Tensor`` among the positional arguments (or, if none, the
    keyword arguments) selects the context ``custom_device_ctx(tensor.device.index)``.
    With no tensor argument, ``contextlib.nullcontext()`` is used. Every tensor
    argument is replaced by ``.contiguous()``; other values pass through unchanged.
    """

    def _as_contiguous(value):
        return value.contiguous() if isinstance(value, torch.Tensor) else value

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        guarded_kwargs = {name: _as_contiguous(val) for name, val in kwargs.items()}

        anchor = next(
            (obj for obj in (*args, *kwargs.values()) if isinstance(obj, torch.Tensor)),
            None,
        )
        if anchor is None:
            ctx = contextlib.nullcontext()
        else:
            ctx = custom_device_ctx(anchor.device.index)

        with ctx:
            # Positional arguments are converted lazily, inside the device context,
            # at the moment they are unpacked into the call.
            return fn(*(_as_contiguous(arg) for arg in args), **guarded_kwargs)

    return wrapper
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
# Same Intel Arc A-series check as `is_intel_a770` above, exposed under a second name.
# torch.xpu.get_device_name(0) is only evaluated when is_intel is True.
is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
cmake_minimum_required(VERSION 3.18)
project(wkv7 LANGUAGES CXX CUDA)

find_package(CUDAToolkit REQUIRED)

# ---------- 1. Locate Python ----------
find_package(Python3 REQUIRED COMPONENTS Interpreter)

# ---------- 2. Ask jax for the XLA FFI header directory ----------
execute_process(
  COMMAND "${Python3_EXECUTABLE}" -c "from jax import ffi; print(ffi.include_dir())"
  OUTPUT_VARIABLE XLA_INCLUDE_DIR
  OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(NOT XLA_INCLUDE_DIR)
  message(FATAL_ERROR "Cannot get XLA include dir from jax.ffi")
endif()
message(STATUS "XLA include directory: ${XLA_INCLUDE_DIR}")

# ---------- 3. Build the shared library ----------
# NOTE(review): wkv7_ffi.cu relies on the _C_ and _CHUNK_LEN_ macros, which are not
# defined anywhere in this file; presumably the caller passes them via CMAKE_CUDA_FLAGS — TODO confirm.
add_library(wkv7 SHARED wkv7_ffi.cu)

# 3-1. Header search path
target_include_directories(wkv7 PRIVATE ${XLA_INCLUDE_DIR})

# 3-2. Link the CUDA runtime
target_link_libraries(wkv7 PRIVATE CUDA::cudart)

# 3-3. Important: C++17 / CUDA 17 language standard
target_compile_features(wkv7 PUBLIC cxx_std_17)
set_target_properties(wkv7 PROPERTIES
  CUDA_STANDARD 17
  CUDA_SEPARABLE_COMPILATION ON
  POSITION_INDEPENDENT_CODE ON
  PREFIX ""   # drop the default "lib" prefix
)

# ---------- 4. Install ----------
# Install the .so straight into the source directory (next to wkv7_jax.py) so ctypes.CDLL can load it
install(TARGETS wkv7
  LIBRARY DESTINATION "${CMAKE_SOURCE_DIR}"
  RUNTIME DESTINATION "${CMAKE_SOURCE_DIR}") # RUNTIME is used on Windows
|
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
#include <cuda_bf16.h>
|
|
2
|
+
#include <cuda_runtime.h>
|
|
3
|
+
#include <xla/ffi/api/ffi.h>
|
|
4
|
+
#include <vector>
|
|
5
|
+
#include <cstdint>
|
|
6
|
+
// ref link:https://github.com/BlinkDL/RWKV-CUDA/tree/main/rwkv7_fast_fused
|
|
7
|
+
namespace ffi = xla::ffi;
|
|
8
|
+
|
|
9
|
+
/* -------------------- Type aliases -------------------- */
using bf = __nv_bfloat16;

/* -------------------- Device-side helpers -------------------- */
// Widen a bfloat16 value to float.
__device__ inline float to_float(const bf &x) { return __bfloat162float(x); }

// Narrow a float to bfloat16 with the round-to-nearest ("_rn") intrinsic.
__device__ inline bf to_bf(const float &x) { return __float2bfloat16_rn(x); }

// __restrict__-qualified bf16 pointer, used for the kernels' input tensors.
typedef bf *__restrict__ F_;
|
|
20
|
+
|
|
21
|
+
/* -------------------- Kernel -------------------- */
// Optimization 1: templated on the head size C plus __launch_bounds__, to improve occupancy.
//
// Training forward pass of the WKV7 recurrence.
// Launch layout (see WKV7FwdHost): grid = (H, B), block = C threads. Block (hh, bb)
// handles one (batch, head); thread i keeps row i of the C x C state in registers:
// state[j] = S[i][j].
// w_, q_, k_, v_, a_, b_, y_ and sa_ are indexed as [B, T, H, C] (see `ind`).
// h0_ is read as one row-major C x C matrix per (batch, head).
// Per step t:
//   w_j     = exp(-exp(w_raw_j))
//   sa_i    = sum_j a_j * S[i][j]           (written to sa_ for the backward pass)
//   S[i][j] = S[i][j] * w_j + sa_i * b_j + k_j * v_i
//   y_i     = sum_j S[i][j] * q_j
// Every _CHUNK_LEN_ steps the state is checkpointed to s_ TRANSPOSED
// (s_[..., j, i] = S[i][j]) so backward_kernel can read column i contiguously.
// NOTE(review): s_ is sized for T / _CHUNK_LEN_ checkpoints, so steps after the last
// full chunk are never checkpointed when T is not a multiple of _CHUNK_LEN_;
// presumably the caller guarantees divisibility — TODO confirm.
template<int C> __launch_bounds__(C, 2)
__global__ void forward_kernel(int T, int H,
                               F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
                               bf *y_, float *s_, float *sa_, float *h0_) {
    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
    float state[C] = {0};
    __shared__ float q[C], k[C], w[C], a[C], b[C];

    // Start of row i of this (batch, head)'s initial state.
    int64_t h0_base = ((int64_t)bb * H + hh) * C * C + i * C;

#pragma unroll
    for (int j = 0; j < C; ++j) state[j] = h0_[h0_base + j];

    for (int t = 0; t < T; ++t) {
        // Optimization 2: force int64_t indexing to avoid overflow.
        int64_t ind = (int64_t)bb * T * H * C + (int64_t)t * H * C + hh * C + i;

        // Wait until every thread has finished reading the previous step's shared
        // values before overwriting them.
        __syncthreads();
        q[i] = to_float(q_[ind]);
        w[i] = __expf(-__expf(to_float(w_[ind])));
        k[i] = to_float(k_[ind]);
        a[i] = to_float(a_[ind]);
        b[i] = to_float(b_[ind]);
        __syncthreads();

        float sa = 0.f;
#pragma unroll
        for (int j = 0; j < C; ++j) sa += a[j] * state[j];
        sa_[ind] = sa;

        float v_val = to_float(v_[ind]);
        float y = 0.f;
#pragma unroll
        for (int j = 0; j < C; ++j) {
            float &s = state[j];
            s = s * w[j] + sa * b[j] + k[j] * v_val;
            y += s * q[j];
        }
        y_[ind] = to_bf(y);

        // Chunk boundary: checkpoint the state, transposed (stride C between j values).
        if ((t + 1) % _CHUNK_LEN_ == 0) {
            int64_t base = ((int64_t)bb * H + hh) * (T / _CHUNK_LEN_) * C * C +
                           ((int64_t)t / _CHUNK_LEN_) * C * C + i;
#pragma unroll
            for (int j = 0; j < C; ++j) s_[base + j * C] = state[j];
        }
    }
}
|
|
71
|
+
|
|
72
|
+
// Optimization 3: backward kernel, templated on C with __launch_bounds__, and
// float4 vectorized loads of the checkpointed forward state.
//
// Launch layout (see WKV7BwdHost): grid = (H, B), block = C threads; block (hh, bb)
// handles one (batch, head). With S[i][j] the forward state (row i = forward
// thread i), thread i keeps in registers:
//   stateT[j]  = S[j][i]    column i of the state, reconstructed backwards in time
//   dstate[j]  = dS[i][j]   row i of the state gradient
//   dstateT[j] = dS[j][i]   column i of the state gradient
// dht_ and dh0_ use the same row-major C x C per-(batch, head) layout as h0_ in
// forward_kernel, so dS[r][c] lives at bh_base + r * C + c.
template<int C> __launch_bounds__(C, 2)
__global__ void backward_kernel(int T, int H,
                                F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_,
                                float *s_, float *sa_, float *dht_, float *dh0_,
                                bf *dw_, bf *dq_, bf *dk_, bf *dv_, bf *da_, bf *db_) {
    // The float4 loads from s_ below need each row offset (i * C) to be a
    // multiple of 4 floats.
    static_assert(C % 4 == 0, "backward_kernel: C must be a multiple of 4");

    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
    float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};

    // Start of this (batch, head)'s C x C matrix in dht_ / dh0_, and of row i within it.
    const int64_t bh_base = ((int64_t)bb * H + hh) * C * C;
    const int64_t dht_base = bh_base + (int64_t)i * C;

    // Seed both views of dS from the final-state gradient. dstateT must be the
    // transpose (column i), matching how it is updated in the loop below.
#pragma unroll
    for (int j = 0; j < C; ++j) {
        dstate[j] = dht_[dht_base + j];                   // dS[i][j]
        dstateT[j] = dht_[bh_base + (int64_t)j * C + i];  // dS[j][i]
    }
    __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
    float qi, wi, ki, ai, bi, dyi;

    for (int t = T - 1; t >= 0; --t) {
        int64_t ind = (int64_t)bb * T * H * C + (int64_t)t * H * C + hh * C + i;

        __syncthreads();
        q[i] = qi = to_float(q_[ind]);
        float wi_fac = -__expf(to_float(w_[ind]));
        w[i] = wi = __expf(wi_fac);
        k[i] = ki = to_float(k_[ind]);
        a[i] = ai = to_float(a_[ind]);
        b[i] = bi = to_float(b_[ind]);
        v[i] = to_float(v_[ind]);
        dy[i] = dyi = to_float(dy_[ind]);
        sa[i] = sa_[ind];
        __syncthreads();

        // At each chunk end, reload column i of S_t from the forward checkpoint
        // (forward_kernel stores it transposed, so the column is contiguous here).
        if ((t + 1) % _CHUNK_LEN_ == 0) {
            int64_t base = ((int64_t)bb * H + hh) * (T / _CHUNK_LEN_) * C * C +
                           ((int64_t)t / _CHUNK_LEN_) * C * C + i * C;

            // Optimization 4: float4 vectorized loads (author's note: ~4x better bandwidth use).
            const float4* s4 = (const float4*)(s_ + base);
#pragma unroll
            for (int j4 = 0; j4 < C / 4; ++j4) {
                float4 q_vec = s4[j4];
                const int j = j4 * 4;
                stateT[j + 0] = q_vec.x;
                stateT[j + 1] = q_vec.y;
                stateT[j + 2] = q_vec.z;
                stateT[j + 3] = q_vec.w;
            }
        }

        // dq_i = sum_j S_t[j][i] * dy_j
        float dq_val = 0.f;
#pragma unroll
        for (int j = 0; j < C; ++j) dq_val += stateT[j] * dy[j];
        dq_[ind] = to_bf(dq_val);

        // Invert step t to recover column i of S_{t-1} (1e-6 guards the division),
        // and add dy_t q_t^T to both views of dS.
        float iwi = 1.f / (wi + 1e-6f);
#pragma unroll
        for (int j = 0; j < C; ++j) {
            stateT[j] = (stateT[j] - ki * v[j] - bi * sa[j]) * iwi;
            dstate[j] += dyi * q[j];
            dstateT[j] += qi * dy[j];
        }

        float dw = 0.f, dk = 0.f, dv = 0.f, db = 0.f, dSb = 0.f;
#pragma unroll
        for (int j = 0; j < C; ++j) {
            dw += dstateT[j] * stateT[j];
            dk += dstateT[j] * v[j];
            dv += dstate[j] * k[j];
            dSb += dstate[j] * b[j];
            db += dstateT[j] * sa[j];
        }
        // Chain rule through w = exp(wi_fac), wi_fac = -exp(w_raw).
        dw_[ind] = to_bf(dw * wi * wi_fac);
        dk_[ind] = to_bf(dk);
        dv_[ind] = to_bf(dv);
        db_[ind] = to_bf(db);

        // Share every thread's dSb (gradient w.r.t. sa_i) across the block.
        __syncthreads();
        dSb_shared[i] = dSb;
        __syncthreads();

        float da = 0.f;
#pragma unroll
        for (int j = 0; j < C; ++j) da += stateT[j] * dSb_shared[j];
        da_[ind] = to_bf(da);

        // Propagate dS from step t to step t-1.
#pragma unroll
        for (int j = 0; j < C; ++j) {
            dstate[j] = dstate[j] * w[j] + dSb * a[j];
            dstateT[j] = dstateT[j] * wi + ai * dSb_shared[j];
        }
    }

    // After the t == 0 step dstate is row i of dL/dh0; for T == 0 it is simply dht.
#pragma unroll
    for (int j = 0; j < C; ++j) dh0_[dht_base + j] = dstate[j];
}
|
|
167
|
+
|
|
168
|
+
/* -------------------- Inference-only kernel -------------------- */
// Same recurrence as forward_kernel, but sa_ and per-chunk checkpoints are not
// written. Only the final state is stored, row-major per (batch, head)
// (s_[bh][i][j] = S[i][j]), i.e. the same layout used to read h0_.
// Launch layout (see WKV7InferenceHost): grid = (H, B), block = C threads;
// thread i owns row i of the state.
template<int C> __launch_bounds__(C, 2)
__global__ void forward_inference_kernel(int T, int H,
                                         F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_,
                                         bf *y_, float *s_, float *h0_) {
    int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
    float state[C] = {0};
    __shared__ float q[C], k[C], w[C], a[C], b[C];

    // Start of row i of this (batch, head)'s initial state.
    int64_t h0_base = ((int64_t)bb * H + hh) * C * C + i * C;

#pragma unroll
    for (int j = 0; j < C; ++j) state[j] = h0_[h0_base + j];

    for (int t = 0; t < T; ++t) {
        int64_t ind = (int64_t)bb * T * H * C + (int64_t)t * H * C + hh * C + i;

        // Don't overwrite shared inputs until every thread is done with the previous step.
        __syncthreads();
        q[i] = to_float(q_[ind]);
        w[i] = __expf(-__expf(to_float(w_[ind])));
        k[i] = to_float(k_[ind]);
        a[i] = to_float(a_[ind]);
        b[i] = to_float(b_[ind]);
        __syncthreads();

        float sa = 0.f;
#pragma unroll
        for (int j = 0; j < C; ++j) sa += a[j] * state[j];

        float v_val = to_float(v_[ind]);
        float y = 0.f;
#pragma unroll
        for (int j = 0; j < C; ++j) {
            float &s = state[j];
            s = s * w[j] + sa * b[j] + k[j] * v_val;
            y += s * q[j];
        }
        y_[ind] = to_bf(y);
    }

    // Write row i of the final state.
    int64_t base = ((int64_t)bb * H + hh) * C * C + i * C;
#pragma unroll
    for (int j = 0; j < C; ++j) s_[base + j] = state[j];
}
|
|
212
|
+
|
|
213
|
+
/* -------------------- Host functions (parameter names unified) -------------------- */
// Launches forward_kernel<_C_> on `stream`.
// The kernel indexes w, q, k, v, a, b, y, sa as [B, T, H, C] and h0 as one C x C
// float matrix per (batch, head). _C_ / _CHUNK_LEN_ are compile-time macros,
// presumably supplied by the build.
// Returns InvalidArgument if w is not rank 4 with last dimension C, and Internal
// if the launch reports a CUDA error.
static ffi::Error WKV7FwdHost(
    cudaStream_t stream,
    ffi::Buffer<ffi::BF16> w,
    ffi::Buffer<ffi::BF16> q,
    ffi::Buffer<ffi::BF16> k,
    ffi::Buffer<ffi::BF16> v,
    ffi::Buffer<ffi::BF16> a,  // formerly 'z'; maps directly to the kernel's a_
    ffi::Buffer<ffi::BF16> b,  // formerly 'a'; maps directly to the kernel's b_
    ffi::Buffer<ffi::F32> h0,
    ffi::ResultBuffer<ffi::BF16> y,
    ffi::ResultBuffer<ffi::F32> s,
    ffi::ResultBuffer<ffi::F32> sa)
{
    constexpr int C = _C_;
    auto dims = w.dimensions();
    // The kernel hard-codes the head size C and the [B, T, H, C] layout; any
    // other shape would make it read and write out of bounds.
    if (dims.size() != 4 || dims[3] != C)
        return ffi::Error::InvalidArgument(
            "wkv7 forward: expected w with shape [B, T, H, " + std::to_string(C) + "]");
    const int B = static_cast<int>(dims[0]);
    const int T = static_cast<int>(dims[1]);
    const int H = static_cast<int>(dims[2]);
    // A grid with a zero dimension is an invalid launch configuration, and with
    // no (batch, head) pairs there is nothing to compute.
    if (B == 0 || H == 0) return ffi::Error::Success();

    dim3 block(C);
    dim3 grid(H, B);

    // Template instantiation; arguments map one-to-one onto the kernel parameters.
    forward_kernel<_C_><<<grid, block, 0, stream>>>(
        T, H,
        reinterpret_cast<bf *>(w.typed_data()),
        reinterpret_cast<bf *>(q.typed_data()),
        reinterpret_cast<bf *>(k.typed_data()),
        reinterpret_cast<bf *>(v.typed_data()),
        reinterpret_cast<bf *>(a.typed_data()),  // -> a_
        reinterpret_cast<bf *>(b.typed_data()),  // -> b_
        reinterpret_cast<bf *>(y->typed_data()),
        s->typed_data(),
        sa->typed_data(),
        h0.typed_data());

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        return ffi::Error::Internal(
            std::string("CUDA forward_kernel error: ") + cudaGetErrorString(err));
    return ffi::Error::Success();
}
|
|
253
|
+
|
|
254
|
+
// Launches backward_kernel<_C_> on `stream`.
// Same [B, T, H, C] layout as the forward pass. s and sa are the checkpoints
// produced by forward_kernel; dht / dh0 are C x C float matrices per (batch, head).
// Returns InvalidArgument if w is not rank 4 with last dimension C, and Internal
// if the launch reports a CUDA error.
static ffi::Error WKV7BwdHost(
    cudaStream_t stream,
    ffi::Buffer<ffi::BF16> w,
    ffi::Buffer<ffi::BF16> q,
    ffi::Buffer<ffi::BF16> k,
    ffi::Buffer<ffi::BF16> v,
    ffi::Buffer<ffi::BF16> a,  // formerly 'z'; maps directly to the kernel's a_
    ffi::Buffer<ffi::BF16> b,  // formerly 'a'; maps directly to the kernel's b_
    ffi::Buffer<ffi::BF16> dy,
    ffi::Buffer<ffi::F32> s,
    ffi::Buffer<ffi::F32> sa,
    ffi::Buffer<ffi::F32> dht,
    ffi::ResultBuffer<ffi::F32> dh0,
    ffi::ResultBuffer<ffi::BF16> dw,
    ffi::ResultBuffer<ffi::BF16> dq,
    ffi::ResultBuffer<ffi::BF16> dk,
    ffi::ResultBuffer<ffi::BF16> dv,
    ffi::ResultBuffer<ffi::BF16> da,
    ffi::ResultBuffer<ffi::BF16> db)
{
    constexpr int C = _C_;
    auto dims = w.dimensions();
    // The kernel hard-codes the head size C and the [B, T, H, C] layout; any
    // other shape would make it read and write out of bounds.
    if (dims.size() != 4 || dims[3] != C)
        return ffi::Error::InvalidArgument(
            "wkv7 backward: expected w with shape [B, T, H, " + std::to_string(C) + "]");
    const int B = static_cast<int>(dims[0]);
    const int T = static_cast<int>(dims[1]);
    const int H = static_cast<int>(dims[2]);
    // A grid with a zero dimension is an invalid launch configuration, and with
    // no (batch, head) pairs there is nothing to compute.
    if (B == 0 || H == 0) return ffi::Error::Success();

    dim3 block(C);
    dim3 grid(H, B);

    // Template instantiation; arguments map one-to-one onto the kernel parameters.
    backward_kernel<_C_><<<grid, block, 0, stream>>>(
        T, H,
        reinterpret_cast<bf *>(w.typed_data()),
        reinterpret_cast<bf *>(q.typed_data()),
        reinterpret_cast<bf *>(k.typed_data()),
        reinterpret_cast<bf *>(v.typed_data()),
        reinterpret_cast<bf *>(a.typed_data()),  // -> a_
        reinterpret_cast<bf *>(b.typed_data()),  // -> b_
        reinterpret_cast<bf *>(dy.typed_data()),
        s.typed_data(),
        sa.typed_data(),
        dht.typed_data(),
        dh0->typed_data(),
        reinterpret_cast<bf *>(dw->typed_data()),
        reinterpret_cast<bf *>(dq->typed_data()),
        reinterpret_cast<bf *>(dk->typed_data()),
        reinterpret_cast<bf *>(dv->typed_data()),
        reinterpret_cast<bf *>(da->typed_data()),
        reinterpret_cast<bf *>(db->typed_data()));

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        return ffi::Error::Internal(
            std::string("CUDA backward_kernel error: ") + cudaGetErrorString(err));
    return ffi::Error::Success();
}
|
|
307
|
+
|
|
308
|
+
// Launches forward_inference_kernel<_C_> on `stream`.
// w, q, k, v, a, b, y use the [B, T, H, C] layout. h0 (input) and s (final state,
// output) are C x C float matrices per (batch, head).
// Returns InvalidArgument if w is not rank 4 with last dimension C, and Internal
// if the launch reports a CUDA error.
static ffi::Error WKV7InferenceHost(
    cudaStream_t stream,
    ffi::Buffer<ffi::BF16> w,
    ffi::Buffer<ffi::BF16> q,
    ffi::Buffer<ffi::BF16> k,
    ffi::Buffer<ffi::BF16> v,
    ffi::Buffer<ffi::BF16> a,  // maps directly to the kernel's a_
    ffi::Buffer<ffi::BF16> b,  // maps directly to the kernel's b_
    ffi::Buffer<ffi::F32> h0,
    ffi::ResultBuffer<ffi::BF16> y,
    ffi::ResultBuffer<ffi::F32> s)
{
    constexpr int C = _C_;
    auto dims = w.dimensions();
    // The kernel hard-codes the head size C and the [B, T, H, C] layout; any
    // other shape would make it read and write out of bounds.
    if (dims.size() != 4 || dims[3] != C)
        return ffi::Error::InvalidArgument(
            "wkv7 inference: expected w with shape [B, T, H, " + std::to_string(C) + "]");
    const int B = static_cast<int>(dims[0]);
    const int T = static_cast<int>(dims[1]);
    const int H = static_cast<int>(dims[2]);
    // A grid with a zero dimension is an invalid launch configuration, and with
    // no (batch, head) pairs there is nothing to compute.
    if (B == 0 || H == 0) return ffi::Error::Success();

    dim3 block(C);
    dim3 grid(H, B);

    // Template instantiation; arguments map one-to-one onto the kernel parameters.
    forward_inference_kernel<_C_><<<grid, block, 0, stream>>>(
        T, H,
        reinterpret_cast<bf *>(w.typed_data()),
        reinterpret_cast<bf *>(q.typed_data()),
        reinterpret_cast<bf *>(k.typed_data()),
        reinterpret_cast<bf *>(v.typed_data()),
        reinterpret_cast<bf *>(a.typed_data()),  // -> a_
        reinterpret_cast<bf *>(b.typed_data()),  // -> b_
        reinterpret_cast<bf *>(y->typed_data()),
        s->typed_data(),
        h0.typed_data());

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        return ffi::Error::Internal(
            std::string("CUDA forward_inference_kernel error: ") + cudaGetErrorString(err));
    return ffi::Error::Success();
}
|
|
345
|
+
|
|
346
|
+
/* -------------------- FFI registration (parameter names aligned) -------------------- */
// Exports the XLA FFI handler symbol `Wkv7Fwd`, bound to WKV7FwdHost.
// The Ctx/Arg/Ret order below must match WKV7FwdHost's parameter order.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    Wkv7Fwd, WKV7FwdHost,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
        .Arg<ffi::Buffer<ffi::BF16>>()   // w
        .Arg<ffi::Buffer<ffi::BF16>>()   // q
        .Arg<ffi::Buffer<ffi::BF16>>()   // k
        .Arg<ffi::Buffer<ffi::BF16>>()   // v
        .Arg<ffi::Buffer<ffi::BF16>>()   // a (formerly z)
        .Arg<ffi::Buffer<ffi::BF16>>()   // b (formerly a)
        .Arg<ffi::Buffer<ffi::F32>>()    // h0
        .Ret<ffi::Buffer<ffi::BF16>>()   // y
        .Ret<ffi::Buffer<ffi::F32>>()    // s
        .Ret<ffi::Buffer<ffi::F32>>()    // sa
    , {ffi::Traits::kCmdBufferCompatible});
|
|
362
|
+
|
|
363
|
+
// Exports the XLA FFI handler symbol `Wkv7Bwd`, bound to WKV7BwdHost.
// The Ctx/Arg/Ret order below must match WKV7BwdHost's parameter order.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    Wkv7Bwd, WKV7BwdHost,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
        .Arg<ffi::Buffer<ffi::BF16>>()   // w
        .Arg<ffi::Buffer<ffi::BF16>>()   // q
        .Arg<ffi::Buffer<ffi::BF16>>()   // k
        .Arg<ffi::Buffer<ffi::BF16>>()   // v
        .Arg<ffi::Buffer<ffi::BF16>>()   // a (formerly z)
        .Arg<ffi::Buffer<ffi::BF16>>()   // b (formerly a)
        .Arg<ffi::Buffer<ffi::BF16>>()   // dy
        .Arg<ffi::Buffer<ffi::F32>>()    // s
        .Arg<ffi::Buffer<ffi::F32>>()    // sa
        .Arg<ffi::Buffer<ffi::F32>>()    // dht
        .Ret<ffi::Buffer<ffi::F32>>()    // dh0
        .Ret<ffi::Buffer<ffi::BF16>>()   // dw
        .Ret<ffi::Buffer<ffi::BF16>>()   // dq
        .Ret<ffi::Buffer<ffi::BF16>>()   // dk
        .Ret<ffi::Buffer<ffi::BF16>>()   // dv
        .Ret<ffi::Buffer<ffi::BF16>>()   // da
        .Ret<ffi::Buffer<ffi::BF16>>()   // db
    , {ffi::Traits::kCmdBufferCompatible});
|
|
385
|
+
|
|
386
|
+
// Exports the XLA FFI handler symbol `Wkv7Inference`, bound to WKV7InferenceHost.
// The Ctx/Arg/Ret order below must match WKV7InferenceHost's parameter order.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    Wkv7Inference, WKV7InferenceHost,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
        .Arg<ffi::Buffer<ffi::BF16>>()   // w
        .Arg<ffi::Buffer<ffi::BF16>>()   // q
        .Arg<ffi::Buffer<ffi::BF16>>()   // k
        .Arg<ffi::Buffer<ffi::BF16>>()   // v
        .Arg<ffi::Buffer<ffi::BF16>>()   // a
        .Arg<ffi::Buffer<ffi::BF16>>()   // b
        .Arg<ffi::Buffer<ffi::F32>>()    // h0
        .Ret<ffi::Buffer<ffi::BF16>>()   // y
        .Ret<ffi::Buffer<ffi::F32>>()    // s (final state)
    , {ffi::Traits::kCmdBufferCompatible});
|