rwkv-ops 0.6.1 (rwkv_ops-0.6.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py
@@ -0,0 +1,305 @@
+ import os
+ import torch
+ from torch.utils.cpp_extension import load
+
+ kernel_dir_name = "torch_kernel"
+
+ use_rocm = "RWKV_USE_ROCM" in os.environ and os.environ["RWKV_USE_ROCM"] == "1"
+
+
+ class RWKVKernelOperator:
+     def __init__(self, head_size, max_sequence_length):
+         current_dir = os.path.dirname(__file__)
+         # current_dir = os.pat
+         if use_rocm:
+             wkv6_cuda = load(
+                 name="wkv6",
+                 sources=[
+                     os.path.join(current_dir, f"{kernel_dir_name}/wkv6_op.cpp"),
+                     os.path.join(current_dir, f"{kernel_dir_name}/wkv6_cuda.cu"),
+                 ],
+                 # verbose=True, extra_cuda_cflags=[f"-D_N_={head_size}", f"-D_T_={max_sequence_length}"])
+                 verbose=True,
+                 extra_cuda_cflags=[
+                     "-fopenmp -ffast-math -munsafe-fp-atomics --gpu-max-threads-per-block=120 -enable-vectorize-compares",
+                     f"-D_N_={head_size}",
+                     f"-D_T_={max_sequence_length}",
+                 ],
+             )
+         else:
+             wkv6_cuda = load(
+                 name="wkv6",
+                 sources=[
+                     os.path.join(current_dir, f"{kernel_dir_name}/wkv6_op.cpp"),
+                     os.path.join(current_dir, f"{kernel_dir_name}/wkv6_cuda.cu"),
+                 ],
+                 # verbose=True, extra_cuda_cflags=[f"-D_N_={head_size}", f"-D_T_={max_sequence_length}"])
+                 verbose=True,
+                 extra_cuda_cflags=[
+                     "-res-usage",
+                     "--use_fast_math",
+                     "-O3",
+                     "-Xptxas -O3",
+                     "--extra-device-vectorization",
+                     f"-D_N_={head_size}",
+                     f"-D_T_={max_sequence_length}",
+                 ],
+             )
+
+         class RWKV_6(torch.autograd.Function):
+             @staticmethod
+             def forward(ctx, B, T, C, H, r, k, v, w, u):
+                 if not isinstance(u, torch.Tensor):
+                     u = u.value
+                 with torch.no_grad():
+                     assert r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
+                     assert r.dtype in [torch.float32, torch.bfloat16, torch.float16]
+
+                     assert head_size == C // H
+                     ctx.B = B
+                     ctx.T = T
+                     ctx.C = C
+                     ctx.H = H
+                     assert r.is_contiguous()
+                     assert k.is_contiguous()
+                     assert v.is_contiguous()
+                     assert w.is_contiguous()
+                     assert u.is_contiguous()
+                     ctx.save_for_backward(r, k, v, w, u)
+
+                     y_dtype = r.dtype if r.dtype != torch.float16 else torch.float32
+
+                     y = torch.empty(
+                         (B, T, C),
+                         device=r.device,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )  # .uniform_(-100, 100)
+
+                     if r.dtype == torch.float32:
+                         wkv6_cuda.forward_fp32(B, T, C, H, r, k, v, w, u, y)
+                     elif r.dtype == torch.bfloat16:
+                         wkv6_cuda.forward_bf16(B, T, C, H, r, k, v, w, u, y)
+                     else:
+                         wkv6_cuda.forward_fp16(B, T, C, H, r, k, v, w, u, y)
+                     return y
+
+             @staticmethod
+             def backward(ctx, gy):
+                 assert gy.is_cuda
+                 with torch.no_grad():
+                     assert gy.dtype in [torch.bfloat16, torch.float32]
+                     B = ctx.B
+                     T = ctx.T
+                     C = ctx.C
+                     H = ctx.H
+                     assert gy.is_contiguous()
+                     r, k, v, w, u = ctx.saved_tensors
+                     y_dtype = r.dtype if r.dtype != torch.float16 else torch.float32
+
+                     gr = torch.empty(
+                         (B, T, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )  # .uniform_(-100, 100)
+                     gk = torch.empty(
+                         (B, T, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )  # .uniform_(-100, 100)
+                     gv = torch.empty(
+                         (B, T, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )  # .uniform_(-100, 100)
+                     gw = torch.empty(
+                         (B, T, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )  # .uniform_(-100, 100)
+                     gu = torch.empty(
+                         (B, C),
+                         device=gy.device,
+                         requires_grad=False,
+                         dtype=y_dtype,
+                         memory_format=torch.contiguous_format,
+                     )  # .uniform_(-100, 100)
+
+                     if r.dtype == torch.float32:
+                         wkv6_cuda.backward_fp32(
+                             B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
+                         )
+                     elif r.dtype == torch.bfloat16:
+                         wkv6_cuda.backward_bf16(
+                             B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
+                         )
+                     else:
+                         wkv6_cuda.backward_fp16(
+                             B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gw, gu
+                         )
+
+                     gu = torch.sum(gu, 0).view(H, C // H)
+
+                     return (None, None, None, None, gr, gk, gv, gw, gu)
+
+         class RWKV_6_with_state:
+             @staticmethod
+             def apply(B, T, C, H, S, s_map, r, k, v, w, u, s):
+                 with torch.no_grad():
+                     assert s_map.dtype == torch.int64, (
+                         "s_map must be None or an int64 array of length B."
+                     )
+                     assert (s is None and s_map is None) or (
+                         s is not None and s_map is not None
+                     ), "init_state and s_map must both be None or both be provided."
+                     assert (
+                         r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
+                         and r.dtype in [torch.float16, torch.float32, torch.bfloat16]
+                     ), "r, k, v, w, u must share the same dtype, one of fp16, fp32, or bf16."
+                     if r.dtype in [torch.float32, torch.bfloat16]:
+                         o_dtype = r.dtype
+                     else:
+                         o_dtype = torch.float32
+                     assert (
+                         r.device
+                         == k.device
+                         == v.device
+                         == w.device
+                         == u.device
+                         == s.device
+                         == s_map.device
+                     ), "Please make sure r, k, v, w, u, s and s_map are on the same device."
+
+                     y = torch.empty(
+                         (B, T, C),
+                         device=r.device,
+                         dtype=o_dtype,
+                         memory_format=torch.contiguous_format,
+                     )
+                     ys = torch.empty(
+                         (B, H, head_size, head_size),
+                         device=r.device,
+                         dtype=o_dtype,
+                         memory_format=torch.contiguous_format,
+                     )
+                     # print(ys)
+                     if r.dtype == torch.bfloat16:
+                         wkv6_cuda.forward_with_state_bf16(
+                             B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
+                         )
+                     elif r.dtype == torch.float32:
+                         wkv6_cuda.forward_with_state_fp32(
+                             B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
+                         )
+                     else:
+                         wkv6_cuda.forward_with_state_fp16(
+                             B, T, C, H, S, s_map, r, k, v, w, u, s, y, ys
+                         )
+
+                     return y, ys
+
+         self.head_size = head_size
+         self.normal_kernel = RWKV_6
+         self.kernel_with_state = RWKV_6_with_state
+
+     def __call__(
+         self, r, k, v, w, u, with_state=False, init_state=None, state_map=None
+     ):
+         B, T, C = r.shape
+         assert C % self.head_size == 0
+         H = C // self.head_size
+         if not isinstance(u, torch.Tensor):
+             u = u.value
+
+         assert r.is_cuda
+         assert k.is_cuda
+         assert v.is_cuda
+         assert w.is_cuda
+         assert u.is_cuda
+
+         if isinstance(r, torch.Tensor):
+             assert r.device == k.device == v.device == w.device == u.device
+         else:
+             assert r.get_device() == k.get_device() == v.get_device() == w.get_device() == u.get_device()
+
+         assert r.dtype == k.dtype == v.dtype == w.dtype == u.dtype
+
+         if r.dtype in [torch.float32, torch.bfloat16]:
+             s_dtype = r.dtype
+         else:
+             s_dtype = torch.float32
+
+         is_custom_init = init_state is not None
+
+         if init_state is not None:
+             assert len(init_state.shape) in [3, 4], (
+                 "init_state must have shape (state_kinds /*<= batch_size*/, num_heads, head_size, head_size) or (num_heads, head_size, head_size)."
+             )
+             if len(init_state.shape) == 3:
+                 init_state = init_state[None, :]
+             assert (
+                 init_state.shape[1:] == (H, self.head_size, self.head_size)
+                 and init_state.shape[0] <= B
+             ), (
+                 "init_state must have shape (state_kinds /*<= batch_size*/, num_heads, head_size, head_size) or (num_heads, head_size, head_size)."
+             )
+
+             assert init_state.dtype == s_dtype, f"init_state dtype must be {s_dtype}."
+             assert init_state.device == r.device
+
+         if state_map is not None:
+             if isinstance(state_map, list):
+                 state_map = torch.tensor(state_map, dtype=torch.int64)
+             elif isinstance(state_map, torch.Tensor):
+                 assert state_map.dtype in [
+                     torch.int32,
+                     torch.int64,
+                 ], "state_map must be an int64 index array of length batch_size."
+                 state_map = state_map.to(torch.int64)
+             assert state_map.shape == (B,), "state_map must have shape (batch_size,)."
+             assert state_map.device == r.device
+
+         if with_state:
+             if init_state is None:
+                 assert state_map is None, (
+                     "state_map can only be used when init_state is provided."
+                 )
+                 init_state = torch.zeros((0,), device=r.device, dtype=s_dtype)
+                 state_map = torch.zeros((0,), device=r.device, dtype=torch.int64)
+             else:
+                 n_state = init_state.shape[0]
+                 if state_map is None:
+                     assert n_state == 1 or n_state == B, (
+                         "Cannot infer state_map from init_state; please specify it explicitly."
+                     )
+                     if n_state == 1:
+                         state_map = torch.tensor(
+                             [0] * B, dtype=torch.int64, device=r.device
+                         )
+                     elif n_state == B:
+                         state_map = torch.tensor(
+                             [i for i in range(B)], dtype=torch.int64, device=r.device
+                         )
+                     else:
+                         assert False, "not implemented"
+                 else:
+                     assert state_map.shape == (B,), "state_map must have shape (batch_size,)."
+                     assert (state_map >= 0).all() and (state_map < n_state).all(), (
+                         f"state_map values must be integers in [0, {n_state})."
+                     )
+             # print('state map:',state_map)
+             o, ys = self.kernel_with_state.apply(
+                 B, T, C, H, is_custom_init, state_map, r, k, v, w, u, init_state
+             )
+             return o, ys
+         else:
+             o = self.normal_kernel.apply(B, T, C, H, r, k, v, w, u)
+             return o, None
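
The operator above JIT-compiles the bundled wkv6 CUDA (or HIP) sources on first construction and exposes two paths through __call__: a stateless autograd path and a forward-only path that carries per-sequence recurrent state. A minimal usage sketch, assuming a CUDA-enabled PyTorch install and a local toolchain that can build the bundled sources; the shapes, dtypes, and import path below are illustrative, not prescribed by the package:

# Hypothetical usage sketch; not part of the package.
import torch
from rwkv_ops.rwkv6_kernel.torch_rwkv_kernel import RWKVKernelOperator

head_size, max_len = 64, 512
op = RWKVKernelOperator(head_size=head_size, max_sequence_length=max_len)  # triggers the JIT build

B, T, C = 2, max_len, 8 * head_size            # C must be a multiple of head_size
H = C // head_size
r, k, v, w = (torch.randn(B, T, C, device="cuda", dtype=torch.bfloat16).contiguous() for _ in range(4))
u = torch.randn(H, head_size, device="cuda", dtype=torch.bfloat16).contiguous()

y, state = op(r, k, v, w, u)                          # stateless path: state is None
y2, final_state = op(r, k, v, w, u, with_state=True)  # forward-only path: (B, H, head_size, head_size) states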
rwkv_ops/rwkv7_kernel/__init__.py
@@ -0,0 +1,113 @@
+ import keras
+ from distutils.util import strtobool
+ import os
+ from keras import ops
+
+
+ def transpose_head(x, head_first):
+     if head_first:
+         return ops.transpose(x, (0, 2, 1, 3))
+     else:
+         return x
+
+
+ def get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
+     assert HEAD_SIZE % 4 == 0
+     from .native_keras_op import generalized_delta_rule as native_op
+
+     if keras.config.backend() == "torch":
+         import torch
+
+         if torch.cuda.is_available():
+             if KERNEL_TYPE.lower() == "triton":
+                 from .torch_op import generalized_delta_rule
+
+                 return generalized_delta_rule, generalized_delta_rule, True
+             elif KERNEL_TYPE.lower() == "cuda":
+                 from .torch_cuda_kernel.wkv7_torch import (
+                     get_torch_generalized_delta_rule,
+                 )
+
+                 return get_torch_generalized_delta_rule(HEAD_SIZE) + [False]
+     elif keras.config.backend() == "jax":
+         import jax
+         import os
+
+         if jax.devices()[0].platform == "gpu":
+             if KERNEL_TYPE.lower() == "triton":
+                 os.environ["JAX_LOG_COMPUTATION"] = "0"
+                 from .jax_op import generalized_delta_rule
+
+                 return generalized_delta_rule, native_op, False
+             elif KERNEL_TYPE.lower() == "cuda":
+                 from .jax_cuda_kernel.wkv7_jax import get_jax_generalized_delta_rule
+
+                 return get_jax_generalized_delta_rule(HEAD_SIZE) + [False]
+     elif keras.config.backend() == "tensorflow":
+         import tensorflow as tf
+
+         if len(tf.config.list_physical_devices("GPU")) > 0:
+             if KERNEL_TYPE.lower() == "cuda" and HEAD_SIZE:
+                 try:
+                     from jax.lib import xla_bridge
+
+                     assert xla_bridge.get_backend().platform == "gpu"
+                 except Exception:
+                     raise RuntimeError(
+                         "The TensorFlow CUDA kernel depends on the JAX kernel. "
+                         "Make sure the kernel is usable from JAX before using it from TensorFlow."
+                     )
+                 print("🎉" * 10)
+                 print("The TensorFlow CUDA kernel only supports the forward pass; gradients are not available.")
+                 print("🎉" * 10)
+                 from .tf_eager_kernel import get_tf_generalized_delta_rule
+
+                 generalized_delta_rule_inference = get_tf_generalized_delta_rule(
+                     HEAD_SIZE
+                 )
+                 return native_op, generalized_delta_rule_inference, False
+     elif keras.config.backend() == "mlx" and KERNEL_TYPE.lower() == "cuda":
+         from .mlx_op import generalized_delta_rule
+
+         return native_op, generalized_delta_rule, False
+     return native_op, native_op, False
+
+
+ def get_rnn_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="native"):
+     assert HEAD_SIZE % 4 == 0
+     from .native_keras_op import generalized_delta_rule
+
+     if KERNEL_TYPE == "cuda":
+         if keras.config.backend() == "jax":
+             import jax
+
+             if jax.devices()[0].platform == "gpu":
+                 from .jax_cuda_kernel_single.wkv7_single_step_jax import (
+                     get_jax_generalized_delta_rule_single_step,
+                 )
+
+                 return get_jax_generalized_delta_rule_single_step(HEAD_SIZE)
+         elif keras.config.backend() == "torch":
+             import torch
+
+             if torch.cuda.is_available():
+                 from .torch_cuda_kernel_single.wkv7_single_step_torch import (
+                     get_torch_generalized_delta_rule_single_step,
+                 )
+
+                 return get_torch_generalized_delta_rule_single_step(HEAD_SIZE)
+         elif keras.config.backend() == "tensorflow":
+             import tensorflow as tf
+
+             if len(tf.config.list_physical_devices("GPU")) > 0:
+                 try:
+                     import jax
+                 except ImportError:
+                     return generalized_delta_rule
+                 if jax.devices()[0].platform == "gpu":
+                     from .tf_eager_kernel import (
+                         get_tf_generalized_delta_rule_single_step,
+                     )
+
+                     return get_tf_generalized_delta_rule_single_step(HEAD_SIZE)
+     return generalized_delta_rule
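
get_generalized_delta_rule selects a kernel implementation for the active Keras backend and returns a triple; the code above suggests a training kernel, an inference/fallback kernel, and a boolean flag, though the flag's exact meaning is not documented in this file. A hedged call sketch, assuming a GPU-visible backend; the variable names are illustrative:

# Illustrative only; the semantics of the returned triple are inferred from the selector above.
import keras
from rwkv_ops.rwkv7_kernel import get_generalized_delta_rule

train_kernel, infer_kernel, flag = get_generalized_delta_rule(HEAD_SIZE=64, KERNEL_TYPE="triton")
print(keras.config.backend(), train_kernel is infer_kernel, flag)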
rwkv_ops/rwkv7_kernel/get_jax_devices_info.py
@@ -0,0 +1,220 @@
+ import os
+ from functools import lru_cache
+ from typing import Literal
+ import functools
+ import triton
+ import jax
+ import jax.numpy as jnp
+ from enum import Enum
+ import contextlib
+
+
+ @lru_cache(maxsize=None)
+ def get_multiprocessor_count(tensor_idx: int = 0) -> int:
+     return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
+         "multiprocessor_count"
+     ]
+
+
+ @lru_cache(maxsize=None)
+ def get_available_device() -> str:
+     try:
+         return triton.runtime.driver.active.get_current_target().backend
+     except BaseException:
+         import warnings
+
+         warnings.warn(
+             ("Triton is not supported on current platform, roll back to CPU."),
+             stacklevel=1,
+         )
+         return "cpu"
+
+
+ @lru_cache(maxsize=None)
+ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
+     device = get_available_device()
+     if device == "cuda":
+         return "nvidia"
+     elif device == "hip":
+         return "amd"
+     elif device == "xpu":
+         return "intel"
+     else:
+         return device
+
+
+ # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
+ # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
+ # Therefore, we need to check the triton backend to determine the actual GPU vendor.
+ device = get_available_device() if get_available_device() != "hip" else "cuda"
+
+ device_platform = _check_platform()
+
+ is_intel = device_platform == "intel"
+ is_nvidia = device_platform == "nvidia"
+ is_amd = device_platform == "amd"
+
+ use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
+
+
+ device = get_available_device() if get_available_device() != "hip" else "cuda"
+
+ is_intel_a770 = False
+ device = jax.devices()
+ is_tf32_supported = is_nvidia
+
+
+ def get_all_max_shared_memory():
+     return [
+         triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
+         for i in range(len(jax.devices()))
+     ]
+
+
+ device_shared_mem_list = get_all_max_shared_memory()
+
+
+ @lru_cache(maxsize=None)
+ def is_triton_shared_mem_enough(
+     max_shared_mem: int = 102400, tensor_idx: int = 0
+ ) -> bool:
+     max_shared_memory = device_shared_mem_list[tensor_idx]
+     return max_shared_memory >= max_shared_mem
+
+
+ device_capacity = is_triton_shared_mem_enough()
+
+
+ def _cpu_device_warning():
+     import warnings
+
+     warnings.warn(
+         ("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
+     )
+
+
+ def get_all_max_shared_mem():
+     try:
+         return [
+             triton.runtime.driver.active.utils.get_device_properties(i)[
+                 "max_shared_mem"
+             ]
+             for i in range(len(jax.devices()))
+         ]
+     except BaseException:
+         _cpu_device_warning()
+         return [-1]
+
+
+ class Backend(Enum):
+     ADA = 101376  # RTX 4090
+     AMPERE = 166912  # A100
+     HOPPER = 232448  # H100
+     DEFAULT = 102400  # Default
+
+     @classmethod
+     def get_shared_memory(cls, arch: str) -> int:
+         try:
+             return cls[arch.upper()].value
+         except KeyError:
+             return cls.DEFAULT.value
+
+
+ @lru_cache(maxsize=None)
+ def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
+     try:
+         device_shared_mem_list = get_all_max_shared_mem()
+         max_shared_memory = device_shared_mem_list[tensor_idx]
+         return max_shared_memory >= Backend.get_shared_memory(arch)
+     except Exception:
+         return False
+
+
+ def tensor_cache(fn):
+     """
+     A decorator that caches the most recent result of a function with tensor inputs.
+
+     This decorator will store the output of the decorated function for the most recent set of input tensors.
+     If the function is called again with the same input tensors, it will return the cached result.
+
+
+     Args:
+         fn (Callable[..., jax.Array]):
+             The function to be decorated. It should take tensor inputs and return tensor outputs.
+
+     Returns:
+         Callable[..., jax.Array]:
+             A wrapped version of the input function with single-entry caching.
+     """
+     last_args = None
+     last_kwargs = None
+     last_result = None
+
+     @functools.wraps(fn)
+     def wrapper(*args, **kwargs):
+         nonlocal last_args, last_kwargs, last_result
+
+         if last_args is not None and last_kwargs is not None:
+             if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
+                 if all(a is b for a, b in zip(args, last_args)) and all(
+                     k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
+                 ):
+                     return last_result
+
+         result = fn(*args, **kwargs)
+         last_args, last_kwargs, last_result = args, kwargs, result
+         return result
+
+     return wrapper
+
+
+ @tensor_cache
+ def prepare_lens(cu_seqlens):
+     return cu_seqlens[1:] - cu_seqlens[:-1]
+
+
+ @tensor_cache
+ def prepare_chunk_indices(cu_seqlens, chunk_size: int):
+     indices = jnp.concatenate(
+         [
+             jnp.arange(n)
+             for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
+         ]
+     )
+
+     return jnp.stack([jnp.cumsum(jnp.equal(indices, 0), 0) - 1, indices], 1)
+
+
+ def input_guard(fn):
+     """
+     A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
+     """
+
+     @functools.wraps(fn)
+     def wrapper(*args, **kwargs):
+         contiguous_args = (i for i in args)
+         contiguous_kwargs = {k: v for k, v in kwargs.items()}
+
+         tensor = None
+         for arg in args:
+             if isinstance(arg, jax.Array):
+                 tensor = arg
+                 break
+         if tensor is None:
+             for value in kwargs.values():
+                 if isinstance(value, jax.Array):
+                     tensor = value
+                     break
+
+         if tensor is not None:
+             ctx = tensor.device
+         else:
+             ctx = contextlib.nullcontext()
+
+         with ctx:
+             return fn(*contiguous_args, **contiguous_kwargs)
+
+     return wrapper
+
+
+ is_intel_alchemist = False
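
tensor_cache above memoizes only the most recent call and keys the cache on object identity (is), which matches the pattern of repeatedly deriving chunk metadata from the same cu_seqlens array within a step. A small sketch of that behaviour, assuming the module imports cleanly (it queries Triton device properties at import time, so a GPU environment is expected); seq_lens below is a hypothetical helper, not part of the package:

# Sketch of the single-entry, identity-keyed cache; illustrative only.
import jax.numpy as jnp
from rwkv_ops.rwkv7_kernel.get_jax_devices_info import tensor_cache

calls = 0

@tensor_cache
def seq_lens(cu_seqlens):
    global calls
    calls += 1
    return cu_seqlens[1:] - cu_seqlens[:-1]

cu = jnp.array([0, 4, 10, 16])
seq_lens(cu)                         # computed
seq_lens(cu)                         # same array object -> cache hit
seq_lens(jnp.array([0, 4, 10, 16]))  # equal values but a new object -> recomputed
assert calls == 2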