rwkv-ops 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of rwkv-ops might be problematic.

Files changed (43)
  1. rwkv_ops/__init__.py +26 -0
  2. rwkv_ops/rwkv7_kernel/__init__.py +153 -0
  3. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +221 -0
  4. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  5. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  6. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  7. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  8. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  9. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  10. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  11. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  12. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  13. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  14. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  15. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  16. rwkv_ops/rwkv7_kernel/native_keras_op.py +95 -0
  17. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  18. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  19. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  20. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  21. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  22. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  23. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  24. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  25. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  26. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  27. rwkv_ops/rwkv7_kernel/torch_op.py +523 -0
  28. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  29. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  30. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  31. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  32. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  33. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  34. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  35. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  36. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  37. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  38. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  39. rwkv_ops-0.1.0.dist-info/LICENSE.txt +201 -0
  40. rwkv_ops-0.1.0.dist-info/METADATA +118 -0
  41. rwkv_ops-0.1.0.dist-info/RECORD +43 -0
  42. rwkv_ops-0.1.0.dist-info/WHEEL +5 -0
  43. rwkv_ops-0.1.0.dist-info/top_level.txt +1 -0
rwkv_ops/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import os
+
+ KERNEL_TYPE = os.environ.get("KERNEL_TYPE", "triton")
+ KERAS_BACKEND = os.environ.get("KERAS_BACKEND")
+ BACKEND = os.environ.get("KERNEL_BACKEND")
+
+
+ if KERAS_BACKEND is not None:
+     BACKEND = KERAS_BACKEND
+ elif BACKEND is not None:
+     os.environ["KERAS_BACKEND"] = BACKEND
+ else:
+     import torch
+     import keras
+
+     BACKEND = "torch"
+     os.environ["KERAS_BACKEND"] = BACKEND
+     keras.config.set_backend("torch")
+ assert KERNEL_TYPE in ["triton", "cuda", "native"]
+ assert BACKEND in ["torch", "jax", "numpy", "tensorflow"]
+ from .rwkv7_kernel import get_generalized_delta_rule
+
+ generalized_delta_rule, RWKV7_USE_KERNEL = get_generalized_delta_rule(
+     KERNEL_TYPE=KERNEL_TYPE
+ )
+ rwkv7_op = generalized_delta_rule
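
The package __init__.py above resolves everything at import time: KERAS_BACKEND takes precedence over KERNEL_BACKEND, PyTorch is the fallback backend, and KERNEL_TYPE must be one of triton, cuda, or native. A minimal usage sketch of that contract follows; the (r, w, k, v, a, b) argument order and the (output, final_state) return value are taken from the CUDA branch later in this diff, so treat them as assumptions for the native path used here, and the tensor sizes are placeholders.

import os

# Backend selection is read at import time, so set it before importing rwkv_ops.
os.environ["KERNEL_BACKEND"] = "torch"   # or "jax", "tensorflow", "numpy"
os.environ["KERNEL_TYPE"] = "native"     # "triton"/"cuda" require a GPU toolchain

import torch
from rwkv_ops import rwkv7_op, RWKV7_USE_KERNEL

B, T, H, C = 1, 32, 2, 64                # assumed (batch, seq_len, heads, head_size) layout
r, w, k, v, a, b = (torch.randn(B, T, H, C) for _ in range(6))

out, state = rwkv7_op(r, w, k, v, a, b)  # assumed (output, final_state) return, mirroring the CUDA branch
print(out.shape, RWKV7_USE_KERNEL)
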
rwkv_ops/rwkv7_kernel/__init__.py ADDED
@@ -0,0 +1,153 @@
+ import keras
+ from distutils.util import strtobool
+ from keras import ops
+
+
+ def transpose_head(x, head_first):
+     if head_first:
+         return ops.transpose(x, (0, 2, 1, 3))
+     else:
+         return x
+
+
+ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
+     USE_KERNEL = False
+     if keras.config.backend() == "torch":
+         import torch
+
+         if KERNEL_TYPE.lower() == "triton":
+             from .torch_op import generalized_delta_rule
+
+             USE_KERNEL = True
+         elif KERNEL_TYPE.lower() == "cuda":
+             CHUNK_LEN = 16
+             USE_KERNEL = True
+             from torch.utils.cpp_extension import load
+             import os
+
+             flags = [
+                 "-res-usage",
+                 f"-D_C_={HEAD_SIZE}",
+                 f"-D_CHUNK_LEN_={CHUNK_LEN}",
+                 "--use_fast_math",
+                 "-O3",
+                 "-Xptxas -O3",
+                 "--extra-device-vectorization",
+             ]
+             # Get the absolute path of the current file
+             current_file_path = os.path.abspath(__file__)
+
+             # Get the directory containing the current file
+             current_dir_path = os.path.dirname(current_file_path)
+
+             # Get the path of the parent directory
+             parent_dir_path = os.path.abspath(
+                 os.path.join(current_dir_path, os.path.pardir)
+             )
+             load(
+                 name="wind_backstepping",
+                 sources=[
+                     os.path.join(parent_dir_path, "cuda_kernel/wkv7_cuda.cu"),
+                     os.path.join(parent_dir_path, "cuda_kernel/wkv7_op.cpp"),
+                 ],
+                 is_python_module=False,
+                 verbose=True,
+                 extra_cuda_cflags=flags,
+             )
+
+             class WindBackstepping(torch.autograd.Function):
+                 @staticmethod
+                 def forward(ctx, w, q, k, v, z, b):
+                     B, T, H, C = w.shape
+                     DTYPE = q.dtype
+                     q = ops.cast(q, "bfloat16")
+                     k = ops.cast(k, "bfloat16")
+                     v = ops.cast(v, "bfloat16")
+                     z = ops.cast(z, "bfloat16")
+                     b = ops.cast(b, "bfloat16")
+                     w = ops.cast(w, "bfloat16")
+                     assert T % CHUNK_LEN == 0
+                     assert all(i.is_contiguous() for i in [w, q, k, v, z, b])
+                     y = torch.empty_like(v)
+                     s = torch.empty(
+                         B, H, T // CHUNK_LEN, C, C, dtype=torch.float32, device=w.device
+                     )
+                     sa = torch.empty(B, T, H, C, dtype=torch.float32, device=w.device)
+                     torch.ops.wind_backstepping.forward(w, q, k, v, z, b, y, s, sa)
+                     ctx.save_for_backward(w, q, k, v, z, b, s, sa)
+                     return ops.cast(y, DTYPE)
+
+                 @staticmethod
+                 def backward(ctx, dy):
+                     DTYPE = dy.dtype
+                     dy = ops.cast(dy, torch.bfloat16)
+                     dy = dy.contiguous()
+                     assert all(i.dtype == torch.bfloat16 for i in [dy])
+                     assert all(i.is_contiguous() for i in [dy])
+                     w, q, k, v, z, b, s, sa = ctx.saved_tensors
+                     dw, dq, dk, dv, dz, db = [
+                         torch.empty_like(x) for x in [w, q, k, v, z, b]
+                     ]
+                     torch.ops.wind_backstepping.backward(
+                         w, q, k, v, z, b, dy, s, sa, dw, dq, dk, dv, dz, db
+                     )
+                     return (
+                         ops.cast(dw, DTYPE),
+                         ops.cast(dq, DTYPE),
+                         ops.cast(dk, DTYPE),
+                         ops.cast(dv, DTYPE),
+                         ops.cast(dz, DTYPE),
+                         ops.cast(db, DTYPE),
+                     )
+
+             def RUN_CUDA_RWKV7g(q, w, k, v, a, b):
+                 B, T, H, C = q.shape
+                 q = q.contiguous()
+                 w = w.contiguous()
+                 k = k.contiguous()
+                 v = v.contiguous()
+                 a = a.contiguous()
+                 b = b.contiguous()
+                 return WindBackstepping.apply(w, q, k, v, a, b).view(B, T, H * C)
+
+             def generalized_delta_rule(
+                 r: torch.Tensor,
+                 w: torch.Tensor,
+                 k: torch.Tensor,
+                 v: torch.Tensor,
+                 a: torch.Tensor,
+                 b: torch.Tensor,
+                 initial_state: torch.Tensor = None,
+                 output_final_state: bool = True,
+                 head_first: bool = False,
+                 use_chunk: bool = True,
+             ):
+                 r = transpose_head(r, head_first)
+                 k = transpose_head(k, head_first)
+                 v = transpose_head(v, head_first)
+                 a = transpose_head(a, head_first)
+                 b = transpose_head(b, head_first)
+                 w = transpose_head(w, head_first)
+                 return RUN_CUDA_RWKV7g(r, w, k, v, a, b), None
+         else:
+             from .native_keras_op import generalized_delta_rule
+
+             USE_KERNEL = False
+     elif keras.config.backend() == "jax":
+         from jax.lib import xla_bridge
+         import jax
+         import os
+
+         if (
+             xla_bridge.get_backend().platform == "gpu"
+             and KERNEL_TYPE.lower() == "triton"
+         ):
+             from .jax_op import generalized_delta_rule
+
+             USE_KERNEL = True
+         else:
+             from .native_keras_op import generalized_delta_rule
+
+     else:
+         from .native_keras_op import generalized_delta_rule
+     return generalized_delta_rule, USE_KERNEL
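
Two constraints worth noting from the block above: the CUDA path does no padding, so the sequence length T must be a multiple of CHUNK_LEN = 16 (the kernel asserts T % CHUNK_LEN == 0), and head_first=True inputs laid out as (B, H, T, C) are swapped to (B, T, H, C) by transpose_head before the kernel runs. A small standalone illustration of that transpose, using keras.ops as the module does (the concrete sizes are placeholders):

import os
os.environ.setdefault("KERAS_BACKEND", "torch")

import numpy as np
from keras import ops

def transpose_head(x, head_first):
    # Same helper as in rwkv7_kernel/__init__.py: swap the head and time axes
    # when the caller stores tensors head-first.
    return ops.transpose(x, (0, 2, 1, 3)) if head_first else x

B, H, T, C = 1, 2, 32, 64                  # head-first layout; T = 32 satisfies T % 16 == 0
x = np.zeros((B, H, T, C), dtype="float32")
y = transpose_head(x, head_first=True)
print(ops.shape(y))                        # (1, 32, 2, 64), i.e. (B, T, H, C) as the kernels expect
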
rwkv_ops/rwkv7_kernel/get_jax_devices_info.py ADDED
@@ -0,0 +1,221 @@
+ import os
+ from functools import lru_cache
+ from typing import Literal
+ import functools
+ import triton
+ import jax
+ import jax.numpy as jnp
+
+
+ @lru_cache(maxsize=None)
+ def get_multiprocessor_count(tensor_idx: int = 0) -> int:
+     return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
+         "multiprocessor_count"
+     ]
+
+
+ @lru_cache(maxsize=None)
+ def get_available_device() -> str:
+     try:
+         return triton.runtime.driver.active.get_current_target().backend
+     except BaseException:
+         import warnings
+
+         warnings.warn(
+             ("Triton is not supported on current platform, roll back to CPU."),
+             stacklevel=1,
+         )
+         return "cpu"
+
+
+ @lru_cache(maxsize=None)
+ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
+     device = get_available_device()
+     if device == "cuda":
+         return "nvidia"
+     elif device == "hip":
+         return "amd"
+     elif device == "xpu":
+         return "intel"
+     else:
+         return device
+
+
+ # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
+ # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
+ # Therefore, we need to check the triton backend to determine the actual GPU vendor.
+ device = get_available_device() if get_available_device() != "hip" else "cuda"
+
+ device_platform = _check_platform()
+
+ is_intel = device_platform == "intel"
+ is_nvidia = device_platform == "nvidia"
+ is_amd = device_platform == "amd"
+
+ use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
+
+
+ device = get_available_device() if get_available_device() != "hip" else "cuda"
+
+ is_intel_a770 = False
+ device = jax.devices()
+ is_tf32_supported = is_nvidia
+
+
+ def get_all_max_shared_memory():
+     return [
+         triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
+         for i in range(len(jax.devices()))
+     ]
+
+
+ device_shared_mem_list = get_all_max_shared_memory()
+
+
+ @lru_cache(maxsize=None)
+ def is_triton_shared_mem_enough(
+     max_shared_mem: int = 102400, tensor_idx: int = 0
+ ) -> bool:
+     max_shared_memory = device_shared_mem_list[tensor_idx]
+     return max_shared_memory >= max_shared_mem
+
+
+ device_capacity = is_triton_shared_mem_enough()
+
+ from enum import Enum
+ import contextlib
+
+
+ def _cpu_device_warning():
+     import warnings
+
+     warnings.warn(
+         ("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
+     )
+
+
+ def get_all_max_shared_mem():
+     try:
+         return [
+             triton.runtime.driver.active.utils.get_device_properties(i)[
+                 "max_shared_mem"
+             ]
+             for i in range(len(jax.devices()))
+         ]
+     except BaseException:
+         _cpu_device_warning()
+         return [-1]
+
+
+ class Backend(Enum):
+     ADA = 101376  # RTX 4090
+     AMPERE = 166912  # A100
+     HOPPER = 232448  # H100
+     DEFAULT = 102400  # Default
+
+     @classmethod
+     def get_shared_memory(cls, arch: str) -> int:
+         try:
+             return cls[arch.upper()].value
+         except KeyError:
+             return cls.DEFAULT.value
+
+
+ @lru_cache(maxsize=None)
+ def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
+     try:
+         device_shared_mem_list = get_all_max_shared_mem()
+         max_shared_memory = device_shared_mem_list[tensor_idx]
+         return max_shared_memory >= Backend.get_shared_memory(arch)
+     except Exception:
+         return False
+
+
+ def tensor_cache(fn):
+     """
+     A decorator that caches the most recent result of a function with tensor inputs.
+
+     This decorator will store the output of the decorated function for the most recent set of input tensors.
+     If the function is called again with the same input tensors, it will return the cached result.
+
+
+     Args:
+         fn (Callable[..., jax.Array]):
+             The function to be decorated. It should take tensor inputs and return tensor outputs.
+
+     Returns:
+         Callable[..., jax.Array]:
+             A wrapped version of the input function with single-entry caching.
+     """
+     last_args = None
+     last_kwargs = None
+     last_result = None
+
+     @functools.wraps(fn)
+     def wrapper(*args, **kwargs):
+         nonlocal last_args, last_kwargs, last_result
+
+         if last_args is not None and last_kwargs is not None:
+             if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
+                 if all(a is b for a, b in zip(args, last_args)) and all(
+                     k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
+                 ):
+                     return last_result
+
+         result = fn(*args, **kwargs)
+         last_args, last_kwargs, last_result = args, kwargs, result
+         return result
+
+     return wrapper
+
+
+ @tensor_cache
+ def prepare_lens(cu_seqlens):
+     return cu_seqlens[1:] - cu_seqlens[:-1]
+
+
+ @tensor_cache
+ def prepare_chunk_indices(cu_seqlens, chunk_size: int):
+     indices = jnp.concatenate(
+         [
+             jnp.arange(n)
+             for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
+         ]
+     )
+
+     return jnp.stack([jnp.cumsum(jnp.equal(indices, 0), 0) - 1, indices], 1)
+
+
+ def input_guard(fn):
+     """
+     A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
+     """
+
+     @functools.wraps(fn)
+     def wrapper(*args, **kwargs):
+         contiguous_args = (i for i in args)
+         contiguous_kwargs = {k: v for k, v in kwargs.items()}
+
+         tensor = None
+         for arg in args:
+             if isinstance(arg, jax.Array):
+                 tensor = arg
+                 break
+         if tensor is None:
+             for value in kwargs.values():
+                 if isinstance(value, jax.Array):
+                     tensor = value
+                     break
+
+         if tensor is not None:
+             ctx = tensor.device
+         else:
+             ctx = contextlib.nullcontext()
+
+         with ctx:
+             return fn(*contiguous_args, **contiguous_kwargs)
+
+     return wrapper
+
+
+ is_intel_alchemist = False
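
In the file above, tensor_cache memoizes only the most recent call, comparing arguments by identity rather than value, which keeps repeated lookups on the same cu_seqlens array cheap. The sketch below spells out what prepare_lens and prepare_chunk_indices compute for a packed batch, substituting a plain ceil-division for triton.cdiv so it runs without a GPU; it is an illustration of the arithmetic, not the module's actual code path:

import jax.numpy as jnp

# cu_seqlens packs two sequences of lengths 6 and 3 into one flat batch.
cu_seqlens = jnp.array([0, 6, 9])
chunk_size = 4

lens = cu_seqlens[1:] - cu_seqlens[:-1]               # prepare_lens -> [6, 3]
n_chunks = (lens + chunk_size - 1) // chunk_size      # ceil-div, what triton.cdiv does -> [2, 1]

indices = jnp.concatenate([jnp.arange(n) for n in n_chunks.tolist()])
chunk_indices = jnp.stack([jnp.cumsum(jnp.equal(indices, 0), 0) - 1, indices], 1)
print(chunk_indices)   # [[0 0] [0 1] [1 0]]: (sequence id, chunk id within that sequence)
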
rwkv_ops/rwkv7_kernel/get_torch_devices_info.py ADDED
@@ -0,0 +1,250 @@
+ import functools
+ import os
+ from functools import lru_cache
+ from typing import Literal
+
+ import triton
+ from packaging import version
+ import torch
+
+
+ @lru_cache(maxsize=None)
+ def get_multiprocessor_count(tensor_idx: int = 0) -> int:
+     return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
+         "multiprocessor_count"
+     ]
+
+
+ @lru_cache(maxsize=None)
+ def get_available_device() -> str:
+     try:
+         return triton.runtime.driver.active.get_current_target().backend
+     except BaseException:
+         import warnings
+
+         warnings.warn(
+             ("Triton is not supported on current platform, roll back to CPU."),
+             stacklevel=1,
+         )
+         return "cpu"
+
+
+ @lru_cache(maxsize=None)
+ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
+     device = get_available_device()
+     if device == "cuda":
+         return "nvidia"
+     elif device == "hip":
+         return "amd"
+     elif device == "xpu":
+         return "intel"
+     else:
+         return device
+
+
+ # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
+ # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
+ # Therefore, we need to check the triton backend to determine the actual GPU vendor.
+ device = get_available_device() if get_available_device() != "hip" else "cuda"
+
+ device_platform = _check_platform()
+
+ is_intel = device_platform == "intel"
+ is_nvidia = device_platform == "nvidia"
+ is_amd = device_platform == "amd"
+
+ use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
+
+
+ @lru_cache(maxsize=None)
+ def check_pytorch_version(version_s: str = "2.4") -> bool:
+     return version.parse(torch.__version__) >= version.parse(version_s)
+
+
+ is_intel_a770 = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
+ device = get_available_device() if get_available_device() != "hip" else "cuda"
+ device_torch_lib = getattr(torch, device)
+ if check_pytorch_version("2.4"):
+     device = "cuda" if device == "cpu" else device
+     autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
+     autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
+
+     def custom_device_ctx(index: int):
+         return device_torch_lib.device(index)
+ else:
+     assert device == "cuda", (
+         "Only cuda device is supported for PyTorch version < 2.4.0."
+     )
+     autocast_custom_fwd = device_torch_lib.amp.custom_fwd
+     autocast_custom_bwd = device_torch_lib.amp.custom_bwd
+
+     def custom_device_ctx(index: int):
+         return torch.cuda.device(index)
+
+
+ # Nvidia Ampere or newer; AMD and Intel have not been checked yet.
+ is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
+
+
+ def get_all_max_shared_memory():
+     return [
+         triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
+         for i in range(device_torch_lib.device_count())
+     ]
+
+
+ device_shared_mem_list = get_all_max_shared_memory()
+
+
+ @lru_cache(maxsize=None)
+ def is_triton_shared_mem_enough(
+     max_shared_mem: int = 102400, tensor_idx: int = 0
+ ) -> bool:
+     max_shared_memory = device_shared_mem_list[tensor_idx]
+     return max_shared_memory >= max_shared_mem
+
+
+ device_capacity = is_triton_shared_mem_enough()
+ from enum import Enum
+ import contextlib
+
+
+ def _cpu_device_warning():
+     import warnings
+
+     warnings.warn(
+         ("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
+     )
+
+
+ def get_all_max_shared_mem():
+     try:
+         return [
+             triton.runtime.driver.active.utils.get_device_properties(i)[
+                 "max_shared_mem"
+             ]
+             for i in range(device_torch_lib.device_count())
+         ]
+     except BaseException:
+         _cpu_device_warning()
+         return [-1]
+
+
+ class Backend(Enum):
+     ADA = 101376  # RTX 4090
+     AMPERE = 166912  # A100
+     HOPPER = 232448  # H100
+     DEFAULT = 102400  # Default
+
+     @classmethod
+     def get_shared_memory(cls, arch: str) -> int:
+         try:
+             return cls[arch.upper()].value
+         except KeyError:
+             return cls.DEFAULT.value
+
+
+ @lru_cache(maxsize=None)
+ def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
+     try:
+         device_shared_mem_list = get_all_max_shared_mem()
+         max_shared_memory = device_shared_mem_list[tensor_idx]
+         return max_shared_memory >= Backend.get_shared_memory(arch)
+     except Exception:
+         return False
+
+
+ def tensor_cache(fn):
+     """
+     A decorator that caches the most recent result of a function with tensor inputs.
+
+     This decorator will store the output of the decorated function for the most recent set of input tensors.
+     If the function is called again with the same input tensors, it will return the cached result.
+
+
+     Args:
+         fn (Callable[..., torch.Tensor]):
+             The function to be decorated. It should take tensor inputs and return tensor outputs.
+
+     Returns:
+         Callable[..., torch.Tensor]:
+             A wrapped version of the input function with single-entry caching.
+     """
+     last_args = None
+     last_kwargs = None
+     last_result = None
+
+     @functools.wraps(fn)
+     def wrapper(*args, **kwargs):
+         nonlocal last_args, last_kwargs, last_result
+
+         if last_args is not None and last_kwargs is not None:
+             if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
+                 if all(a is b for a, b in zip(args, last_args)) and all(
+                     k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
+                 ):
+                     return last_result
+
+         result = fn(*args, **kwargs)
+         last_args, last_kwargs, last_result = args, kwargs, result
+         return result
+
+     return wrapper
+
+
+ @tensor_cache
+ def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
+     return cu_seqlens[1:] - cu_seqlens[:-1]
+
+
+ @tensor_cache
+ def prepare_chunk_indices(
+     cu_seqlens: torch.LongTensor, chunk_size: int
+ ) -> torch.LongTensor:
+     indices = torch.cat(
+         [
+             torch.arange(n)
+             for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
+         ]
+     )
+     return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
+
+
+ def input_guard(fn):
+     """
+     A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
+     """
+
+     @functools.wraps(fn)
+     def wrapper(*args, **kwargs):
+         contiguous_args = (
+             i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
+         )
+         contiguous_kwargs = {
+             k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
+             for k, v in kwargs.items()
+         }
+
+         tensor = None
+         for arg in args:
+             if isinstance(arg, torch.Tensor):
+                 tensor = arg
+                 break
+         if tensor is None:
+             for value in kwargs.values():
+                 if isinstance(value, torch.Tensor):
+                     tensor = value
+                     break
+
+         if tensor is not None:
+             ctx = custom_device_ctx(tensor.device.index)
+         else:
+             ctx = contextlib.nullcontext()
+
+         with ctx:
+             return fn(*contiguous_args, **contiguous_kwargs)
+
+     return wrapper
+
+
+ is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
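
input_guard above makes every torch.Tensor argument contiguous and runs the wrapped function inside the device context of the first tensor it finds. A usage sketch, assuming a CUDA machine with Triton installed (the module queries Triton device properties at import time) and a made-up row_sums function purely for illustration:

import torch
from rwkv_ops.rwkv7_kernel.get_torch_devices_info import input_guard, check_shared_mem

@input_guard
def row_sums(x: torch.Tensor) -> torch.Tensor:
    # x arrives as a contiguous copy, and we run inside custom_device_ctx(x.device.index).
    return x.sum(dim=-1)

x = torch.randn(4, 8, device="cuda").t()   # a non-contiguous view
print(row_sums(x).shape)                   # the guard hands the function x.contiguous()
print(check_shared_mem("hopper"))          # True only if device 0 reports >= 232448 bytes of shared memory
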
rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from ..jax_kernel.chunk_A_fwd import *
+ from ..jax_kernel.chunk_A_bwd import *
+ from ..jax_kernel.chunk_h_fwd import *
+ from ..jax_kernel.chunk_h_bwd import *
+ from ..jax_kernel.chunk_o_fwd import *
+ from ..jax_kernel.chunk_o_bwd import *
+ from ..jax_kernel.cumsum import *
+ from ..jax_kernel.wy_fast_fwd import *
+ from ..jax_kernel.wy_fast_bwd import *