rwkv_ops-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of rwkv-ops might be problematic.
- rwkv_ops/__init__.py +26 -0
- rwkv_ops/rwkv7_kernel/__init__.py +153 -0
- rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +221 -0
- rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
- rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
- rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
- rwkv_ops/rwkv7_kernel/native_keras_op.py +95 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
- rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
- rwkv_ops/rwkv7_kernel/torch_op.py +523 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
- rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
- rwkv_ops-0.1.0.dist-info/LICENSE.txt +201 -0
- rwkv_ops-0.1.0.dist-info/METADATA +118 -0
- rwkv_ops-0.1.0.dist-info/RECORD +43 -0
- rwkv_ops-0.1.0.dist-info/WHEEL +5 -0
- rwkv_ops-0.1.0.dist-info/top_level.txt +1 -0
rwkv_ops/__init__.py
ADDED
@@ -0,0 +1,26 @@
import os

KERNEL_TYPE = os.environ.get("KERNEL_TYPE", "triton")
KERAS_BACKEND = os.environ.get("KERAS_BACKEND")
BACKEND = os.environ.get("KERNEL_BACKEND")


if KERAS_BACKEND is not None:
    BACKEND = KERAS_BACKEND
elif BACKEND is not None:
    os.environ["KERAS_BACKEND"] = BACKEND
else:
    import torch
    import keras

    BACKEND = "torch"
    os.environ["KERAS_BACKEND"] = BACKEND
    keras.config.set_backend("torch")
assert KERNEL_TYPE in ["triton", "cuda", "native"]
assert BACKEND in ["torch", "jax", "numpy", "tensorflow"]
from .rwkv7_kernel import get_generalized_delta_rule

generalized_delta_rule, RWKV7_USE_KERNEL = get_generalized_delta_rule(
    KERNEL_TYPE=KERNEL_TYPE
)
rwkv7_op = generalized_delta_rule
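Usage note (my sketch, not from the package docs): the environment variables read above have to be set before rwkv_ops is imported, since the backend and kernel selection happens at import time. A minimal example, assuming PyTorch and Keras 3 are installed:

    import os

    os.environ["KERNEL_TYPE"] = "native"    # one of "triton", "cuda", "native"
    os.environ["KERAS_BACKEND"] = "torch"   # or "jax", "numpy", "tensorflow"

    import rwkv_ops

    op = rwkv_ops.rwkv7_op                  # alias of generalized_delta_rule
    print(rwkv_ops.RWKV7_USE_KERNEL)        # True only when a fused kernel path was selected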
rwkv_ops/rwkv7_kernel/__init__.py
ADDED
@@ -0,0 +1,153 @@
import keras
from distutils.util import strtobool
from keras import ops


def transpose_head(x, head_first):
    if head_first:
        return ops.transpose(x, (0, 2, 1, 3))
    else:
        return x


def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
    USE_KERNEL = False
    if keras.config.backend() == "torch":
        import torch

        if KERNEL_TYPE.lower() == "triton":
            from .torch_op import generalized_delta_rule

            USE_KERNEL = True
        elif KERNEL_TYPE.lower() == "cuda":
            CHUNK_LEN = 16
            USE_KERNEL = True
            from torch.utils.cpp_extension import load
            import os

            flags = [
                "-res-usage",
                f"-D_C_={HEAD_SIZE}",
                f"-D_CHUNK_LEN_={CHUNK_LEN}",
                "--use_fast_math",
                "-O3",
                "-Xptxas -O3",
                "--extra-device-vectorization",
            ]
            # Get the absolute path of the current file
            current_file_path = os.path.abspath(__file__)

            # Get the directory containing the current file
            current_dir_path = os.path.dirname(current_file_path)

            # Get the parent directory
            parent_dir_path = os.path.abspath(
                os.path.join(current_dir_path, os.path.pardir)
            )
            load(
                name="wind_backstepping",
                sources=[
                    os.path.join(parent_dir_path, "cuda_kernel/wkv7_cuda.cu"),
                    os.path.join(parent_dir_path, "cuda_kernel/wkv7_op.cpp"),
                ],
                is_python_module=False,
                verbose=True,
                extra_cuda_cflags=flags,
            )

            class WindBackstepping(torch.autograd.Function):
                @staticmethod
                def forward(ctx, w, q, k, v, z, b):
                    B, T, H, C = w.shape
                    DTYPE = q.dtype
                    q = ops.cast(q, "bfloat16")
                    k = ops.cast(k, "bfloat16")
                    v = ops.cast(v, "bfloat16")
                    z = ops.cast(z, "bfloat16")
                    b = ops.cast(b, "bfloat16")
                    w = ops.cast(w, "bfloat16")
                    assert T % CHUNK_LEN == 0
                    assert all(i.is_contiguous() for i in [w, q, k, v, z, b])
                    y = torch.empty_like(v)
                    s = torch.empty(
                        B, H, T // CHUNK_LEN, C, C, dtype=torch.float32, device=w.device
                    )
                    sa = torch.empty(B, T, H, C, dtype=torch.float32, device=w.device)
                    torch.ops.wind_backstepping.forward(w, q, k, v, z, b, y, s, sa)
                    ctx.save_for_backward(w, q, k, v, z, b, s, sa)
                    return ops.cast(y, DTYPE)

                @staticmethod
                def backward(ctx, dy):
                    DTYPE = dy.dtype
                    dy = ops.cast(dy, torch.bfloat16)
                    dy = dy.contiguous()
                    assert all(i.dtype == torch.bfloat16 for i in [dy])
                    assert all(i.is_contiguous() for i in [dy])
                    w, q, k, v, z, b, s, sa = ctx.saved_tensors
                    dw, dq, dk, dv, dz, db = [
                        torch.empty_like(x) for x in [w, q, k, v, z, b]
                    ]
                    torch.ops.wind_backstepping.backward(
                        w, q, k, v, z, b, dy, s, sa, dw, dq, dk, dv, dz, db
                    )
                    return (
                        ops.cast(dw, DTYPE),
                        ops.cast(dq, DTYPE),
                        ops.cast(dk, DTYPE),
                        ops.cast(dv, DTYPE),
                        ops.cast(dz, DTYPE),
                        ops.cast(db, DTYPE),
                    )

            def RUN_CUDA_RWKV7g(q, w, k, v, a, b):
                B, T, H, C = q.shape
                q = q.contiguous()
                w = w.contiguous()
                k = k.contiguous()
                v = v.contiguous()
                a = a.contiguous()
                b = b.contiguous()
                return WindBackstepping.apply(w, q, k, v, a, b).view(B, T, H * C)

            def generalized_delta_rule(
                r: torch.Tensor,
                w: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                a: torch.Tensor,
                b: torch.Tensor,
                initial_state: torch.Tensor = None,
                output_final_state: bool = True,
                head_first: bool = False,
                use_chunk: bool = True,
            ):
                r = transpose_head(r, head_first)
                k = transpose_head(k, head_first)
                v = transpose_head(v, head_first)
                a = transpose_head(a, head_first)
                b = transpose_head(b, head_first)
                w = transpose_head(w, head_first)
                return RUN_CUDA_RWKV7g(r, w, k, v, a, b), None
        else:
            from .native_keras_op import generalized_delta_rule

            USE_KERNEL = False
    elif keras.config.backend() == "jax":
        from jax.lib import xla_bridge
        import jax
        import os

        if (
            xla_bridge.get_backend().platform == "gpu"
            and KERNEL_TYPE.lower() == "triton"
        ):
            from .jax_op import generalized_delta_rule

            USE_KERNEL = True
        else:
            from .native_keras_op import generalized_delta_rule

    else:
        from .native_keras_op import generalized_delta_rule
    return generalized_delta_rule, USE_KERNEL
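A hedged call sketch for the dispatcher above. Assumptions (mine, not stated by the package): Keras is running on the "torch" backend, the native path accepts the same (r, w, k, v, a, b, ...) arguments as the CUDA wrapper defined in this file, and inputs are shaped [batch, seq_len, num_heads, head_size] with head_size == HEAD_SIZE:

    import torch
    from rwkv_ops.rwkv7_kernel import get_generalized_delta_rule

    delta_rule, use_kernel = get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native")

    B, T, H, C = 1, 32, 2, 64   # T a multiple of 16, as the CUDA path asserts
    r, w, k, v, a, b = (torch.randn(B, T, H, C) for _ in range(6))
    out, final_state = delta_rule(r, w, k, v, a, b, head_first=False)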
rwkv_ops/rwkv7_kernel/get_jax_devices_info.py
ADDED
@@ -0,0 +1,221 @@
import os
from functools import lru_cache
from typing import Literal
import functools
import triton
import jax
import jax.numpy as jnp


@lru_cache(maxsize=None)
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
    return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
        "multiprocessor_count"
    ]


@lru_cache(maxsize=None)
def get_available_device() -> str:
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        import warnings

        warnings.warn(
            ("Triton is not supported on current platform, roll back to CPU."),
            stacklevel=1,
        )
        return "cpu"


@lru_cache(maxsize=None)
def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
    device = get_available_device()
    if device == "cuda":
        return "nvidia"
    elif device == "hip":
        return "amd"
    elif device == "xpu":
        return "intel"
    else:
        return device


# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != "hip" else "cuda"

device_platform = _check_platform()

is_intel = device_platform == "intel"
is_nvidia = device_platform == "nvidia"
is_amd = device_platform == "amd"

use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"


device = get_available_device() if get_available_device() != "hip" else "cuda"

is_intel_a770 = False
device = jax.devices()
is_tf32_supported = is_nvidia


def get_all_max_shared_memory():
    return [
        triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
        for i in range(len(jax.devices()))
    ]


device_shared_mem_list = get_all_max_shared_memory()


@lru_cache(maxsize=None)
def is_triton_shared_mem_enough(
    max_shared_mem: int = 102400, tensor_idx: int = 0
) -> bool:
    max_shared_memory = device_shared_mem_list[tensor_idx]
    return max_shared_memory >= max_shared_mem


device_capacity = is_triton_shared_mem_enough()

from enum import Enum
import contextlib


def _cpu_device_warning():
    import warnings

    warnings.warn(
        ("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
    )


def get_all_max_shared_mem():
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)[
                "max_shared_mem"
            ]
            for i in range(len(jax.devices()))
        ]
    except BaseException:
        _cpu_device_warning()
        return [-1]


class Backend(Enum):
    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value


@lru_cache(maxsize=None)
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False


def tensor_cache(fn):
    """
    A decorator that caches the most recent result of a function with tensor inputs.

    This decorator will store the output of the decorated function for the most recent set of input tensors.
    If the function is called again with the same input tensors, it will return the cached result.


    Args:
        fn (Callable[..., jax.Array]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., jax.Array]:
            A wrapped version of the input function with single-entry caching.
    """
    last_args = None
    last_kwargs = None
    last_result = None

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal last_args, last_kwargs, last_result

        if last_args is not None and last_kwargs is not None:
            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
                if all(a is b for a, b in zip(args, last_args)) and all(
                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
                ):
                    return last_result

        result = fn(*args, **kwargs)
        last_args, last_kwargs, last_result = args, kwargs, result
        return result

    return wrapper


@tensor_cache
def prepare_lens(cu_seqlens):
    return cu_seqlens[1:] - cu_seqlens[:-1]


@tensor_cache
def prepare_chunk_indices(cu_seqlens, chunk_size: int):
    indices = jnp.concatenate(
        [
            jnp.arange(n)
            for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
        ]
    )

    return jnp.stack([jnp.cumsum(jnp.equal(indices, 0), 0) - 1, indices], 1)


def input_guard(fn):
    """
    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (i for i in args)
        contiguous_kwargs = {k: v for k, v in kwargs.items()}

        tensor = None
        for arg in args:
            if isinstance(arg, jax.Array):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, jax.Array):
                    tensor = value
                    break

        if tensor is not None:
            ctx = tensor.device
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper


is_intel_alchemist = False
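A small sketch of the variable-length helpers defined above (assumes a Triton-capable JAX GPU environment, since the module probes devices at import time). prepare_chunk_indices turns cumulative sequence lengths into (sequence_id, chunk_id) pairs:

    import jax.numpy as jnp
    from rwkv_ops.rwkv7_kernel.get_jax_devices_info import prepare_lens, prepare_chunk_indices

    cu_seqlens = jnp.array([0, 5, 12])           # two sequences, lengths 5 and 7
    print(prepare_lens(cu_seqlens))              # [5 7]
    print(prepare_chunk_indices(cu_seqlens, 4))  # [[0 0] [0 1] [1 0] [1 1]]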
rwkv_ops/rwkv7_kernel/get_torch_devices_info.py
ADDED
@@ -0,0 +1,250 @@
import functools
import os
from functools import lru_cache
from typing import Literal

import triton
from packaging import version
import torch


@lru_cache(maxsize=None)
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
    return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
        "multiprocessor_count"
    ]


@lru_cache(maxsize=None)
def get_available_device() -> str:
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        import warnings

        warnings.warn(
            ("Triton is not supported on current platform, roll back to CPU."),
            stacklevel=1,
        )
        return "cpu"


@lru_cache(maxsize=None)
def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
    device = get_available_device()
    if device == "cuda":
        return "nvidia"
    elif device == "hip":
        return "amd"
    elif device == "xpu":
        return "intel"
    else:
        return device


# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != "hip" else "cuda"

device_platform = _check_platform()

is_intel = device_platform == "intel"
is_nvidia = device_platform == "nvidia"
is_amd = device_platform == "amd"

use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"


@lru_cache(maxsize=None)
def check_pytorch_version(version_s: str = "2.4") -> bool:
    return version.parse(torch.__version__) >= version.parse(version_s)


is_intel_a770 = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
device = get_available_device() if get_available_device() != "hip" else "cuda"
device_torch_lib = getattr(torch, device)
if check_pytorch_version("2.4"):
    device = "cuda" if device == "cpu" else device
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)

    def custom_device_ctx(index: int):
        return device_torch_lib.device(index)
else:
    assert device == "cuda", (
        "Only cuda device is supported for PyTorch version < 2.4.0."
    )
    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
    autocast_custom_bwd = device_torch_lib.amp.custom_bwd

    def custom_device_ctx(index: int):
        return torch.cuda.device(index)


# Nvidia Ampere or newer; AMD and Intel haven't been checked yet.
is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8


def get_all_max_shared_memory():
    return [
        triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
        for i in range(device_torch_lib.device_count())
    ]


device_shared_mem_list = get_all_max_shared_memory()


@lru_cache(maxsize=None)
def is_triton_shared_mem_enough(
    max_shared_mem: int = 102400, tensor_idx: int = 0
) -> bool:
    max_shared_memory = device_shared_mem_list[tensor_idx]
    return max_shared_memory >= max_shared_mem


device_capacity = is_triton_shared_mem_enough()
from enum import Enum
import contextlib


def _cpu_device_warning():
    import warnings

    warnings.warn(
        ("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
    )


def get_all_max_shared_mem():
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)[
                "max_shared_mem"
            ]
            for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        _cpu_device_warning()
        return [-1]


class Backend(Enum):
    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value


@lru_cache(maxsize=None)
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False


def tensor_cache(fn):
    """
    A decorator that caches the most recent result of a function with tensor inputs.

    This decorator will store the output of the decorated function for the most recent set of input tensors.
    If the function is called again with the same input tensors, it will return the cached result.


    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with single-entry caching.
    """
    last_args = None
    last_kwargs = None
    last_result = None

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal last_args, last_kwargs, last_result

        if last_args is not None and last_kwargs is not None:
            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
                if all(a is b for a, b in zip(args, last_args)) and all(
                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
                ):
                    return last_result

        result = fn(*args, **kwargs)
        last_args, last_kwargs, last_result = args, kwargs, result
        return result

    return wrapper


@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    return cu_seqlens[1:] - cu_seqlens[:-1]


@tensor_cache
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    indices = torch.cat(
        [
            torch.arange(n)
            for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
        ]
    )
    return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)


def input_guard(fn):
    """
    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (
            i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
        )
        contiguous_kwargs = {
            k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
            for k, v in kwargs.items()
        }

        tensor = None
        for arg in args:
            if isinstance(arg, torch.Tensor):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, torch.Tensor):
                    tensor = value
                    break

        if tensor is not None:
            ctx = custom_device_ctx(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper


is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
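A quick sketch of the shared-memory heuristics above (assumes PyTorch, Triton, and a CUDA device are available, since the module probes devices at import time):

    from rwkv_ops.rwkv7_kernel.get_torch_devices_info import Backend, check_shared_mem

    print(Backend.get_shared_memory("hopper"))   # 232448 (H100-class shared memory, bytes)
    print(check_shared_mem("ampere", 0))         # True if device 0 has >= 166912 bytes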
rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py
ADDED
@@ -0,0 +1,9 @@
from ..jax_kernel.chunk_A_fwd import *
from ..jax_kernel.chunk_A_bwd import *
from ..jax_kernel.chunk_h_fwd import *
from ..jax_kernel.chunk_h_bwd import *
from ..jax_kernel.chunk_o_fwd import *
from ..jax_kernel.chunk_o_bwd import *
from ..jax_kernel.cumsum import *
from ..jax_kernel.wy_fast_fwd import *
from ..jax_kernel.wy_fast_bwd import *