rwkv_ops-0.6.1-py3-none-any.whl

Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/rwkv7_kernel/torch_op.py
@@ -0,0 +1,504 @@
# -*- coding: utf-8 -*-
"""
This file implements the forward and backward passes of a chunked delta-rule attention
mechanism, optimized with Triton kernels for GPU acceleration. It includes functions for
forward propagation, backward gradient computation, and integration with PyTorch's
autograd system.
"""

import warnings
from typing import Optional

import torch
import triton

from .torch_kernel.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
from .torch_kernel.chunk_A_fwd import chunk_dplr_fwd_intra
from .torch_kernel.chunk_h_bwd import chunk_dplr_bwd_dhu
from .torch_kernel.chunk_h_fwd import chunk_dplr_fwd_h

from .torch_kernel.chunk_o_bwd import (
    chunk_dplr_bwd_dAu,
    chunk_dplr_bwd_dv,
    chunk_dplr_bwd_o,
)
from .torch_kernel.chunk_o_fwd import chunk_dplr_fwd_o
from .torch_kernel.wy_fast_bwd import chunk_dplr_bwd_wy
from .torch_kernel.wy_fast_fwd import prepare_wy_repr_fwd
from .torch_kernel.cumsum import chunk_rwkv6_fwd_cumsum
from .get_torch_devices_info import (
    autocast_custom_bwd,
    autocast_custom_fwd,
    input_guard,
)


def cast(x, dtype):
    """Cast `x` to `dtype`, passing through `None` and tensors that already match."""
    if x is None or x.dtype == dtype:
        return x
    return x.to(dtype)


def chunk_dplr_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: float = 1,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
    chunk_size: int = 16,
):
    """
    Forward pass of chunked delta-rule attention.

    Args:
        q (torch.Tensor): Queries tensor [B, T, H, K]
        k (torch.Tensor): Keys tensor [B, T, H, K]
        v (torch.Tensor): Values tensor [B, T, H, V]
        a (torch.Tensor): Activations tensor [B, T, H, K]
        b (torch.Tensor): Betas tensor [B, T, H, K]
        gk (torch.Tensor): Log-decay tensor [B, T, H, K]
        scale (float): Scale factor for attention scores
        initial_state (Optional[torch.Tensor]): Initial state for recurrent processing
        output_final_state (bool): Whether to return the final state
        chunk_size (int): Chunk size for processing

    Returns:
        o (torch.Tensor): Output tensor [B, T, H, V]
        final_state (Optional[torch.Tensor]): Final state if requested
    """
    T = q.shape[1]
    BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)

    A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        scale=scale,
        chunk_size=BT,
    )

    del ge

    # A_ab, A_ak, gi, ge: torch.float32
    # A_qk, A_qb, qg, kg, ag, bg: dtype=q.dtype, e.g. bf16
    w, u, _ = prepare_wy_repr_fwd(ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT)

    del A_ab, A_ak
    h, v_new, final_state = chunk_dplr_fwd_h(
        kg=kg,
        bg=bg,
        v=v,
        w=w,
        u=u,
        gk=gi,
        initial_state=initial_state,
        output_final_state=output_final_state,
        chunk_size=BT,
    )

    del u, kg, bg, gi

    o = chunk_dplr_fwd_o(
        qg=qg, v=v, v_new=v_new, A_qk=A_qk, A_qb=A_qb, h=h, chunk_size=BT
    )
    del v_new, h, A_qk, A_qb

    return o, final_state


def chunk_dplr_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    initial_state: torch.Tensor,
    scale,
    do,
    dht,
    BT: int = 16,
):
    """
    Backward pass of chunked delta-rule attention.

    Args:
        q (torch.Tensor): Queries tensor [B, T, H, K]
        k (torch.Tensor): Keys tensor [B, T, H, K]
        v (torch.Tensor): Values tensor [B, T, H, V]
        a (torch.Tensor): Activations tensor [B, T, H, K]
        b (torch.Tensor): Betas tensor [B, T, H, K]
        gk (torch.Tensor): Log-decay tensor [B, T, H, K]
        initial_state (torch.Tensor): Initial state for recurrent processing
        scale (float): Scale factor for attention scores
        do (torch.Tensor): Gradient of the outputs
        dht (torch.Tensor): Gradient of the final hidden state
        BT (int): Chunk size for processing

    Returns:
        A tuple (dq, dk, dv, da, db, dgk, None, dh0, None, None) of gradients
        ordered to match the inputs of `ChunkDPLRDeltaRuleFunction.forward`;
        the `None` entries correspond to the non-tensor arguments `scale`,
        `output_final_state`, and `cu_seqlens`.
    """
    # Recompute the forward intermediates instead of saving them in `ctx`;
    # storing them for long sequences would exhaust GPU memory.
    gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
    A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        scale=scale,
        chunk_size=BT,
    )
    w, u, A_ab_inv = prepare_wy_repr_fwd(
        ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT
    )
    del A_ab
    h, v_new, _ = chunk_dplr_fwd_h(
        kg=kg, bg=bg, v=v, w=w, u=u, gk=gi, initial_state=initial_state, chunk_size=BT
    )
    del u
    # End of recomputation.
    # A_ak, A_ab_inv, gi, ge: torch.float32
    # A_qk, A_qb, qg, kg, ag, bg, v_new: dtype=q.dtype, e.g. bf16

    dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
        v=v, v_new=v_new, do=do, A_qb=A_qb, scale=scale, chunk_size=BT
    )

    dh, dh0, dv_new = chunk_dplr_bwd_dhu(
        qg=qg,
        bg=bg,
        w=w,
        gk=gi,
        h0=initial_state,
        dht=dht,
        do=do,
        dv=dv_new_intra,
        chunk_size=BT,
    )

    dv = chunk_dplr_bwd_dv(A_qk=A_qk, kg=kg, do=do, dh=dh, chunk_size=BT)
    del A_qk

    dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
        k=kg,
        b=bg,
        v=v,
        v_new=v_new,
        do=do,
        h=h,
        dh=dh,
        dv=dv_new,
        w=w,
        gk=gi,
        chunk_size=BT,
        scale=scale,
    )
    del v_new

    dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
        A_ab_inv=A_ab_inv,
        A_ak=A_ak,
        v=v,
        ag=ag,
        dw=dw,
        du=dv_new,
        dv0=dv,
        chunk_size=BT,
    )
    del A_ak

    dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra(
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        dAqk=dA_qk,
        dAqb=dA_qb,
        dAak=dA_ak,
        dAab=dA_ab,
        dgk_last=dgk_last,
        dqg=dqg,
        dkg=dkg,
        dag=dag,
        dbg=dbg,
        chunk_size=BT,
        scale=scale,
    )

    return (
        dq.to(q),
        dk.to(k),
        dv.to(v),
        da.to(a),
        db.to(b),
        dgk.to(gk),
        None,
        dh0,
        None,
        None,
    )


class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):
    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
        gk: torch.Tensor,
        scale: float = 1,
        initial_state: Optional[torch.Tensor] = None,
        output_final_state: bool = True,
        cu_seqlens: Optional[torch.LongTensor] = None,
    ):
        chunk_size = 16
        o, final_state = chunk_dplr_fwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            gk=gk,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            chunk_size=chunk_size,
        )
        ctx.save_for_backward(q, k, v, a, b, gk, initial_state)
        ctx.cu_seqlens = cu_seqlens
        ctx.scale = scale
        ctx.chunk_size = chunk_size
        return o.to(q.dtype), final_state

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
        q, k, v, a, b, gk, initial_state = ctx.saved_tensors
        BT = ctx.chunk_size
        scale = ctx.scale

        return chunk_dplr_bwd(
            q=q,
            k=k,
            v=v,
            a=a,
            b=b,
            gk=gk,
            scale=scale,
            initial_state=initial_state,
            do=do,
            dht=dht,
            BT=BT,
        )


@torch.compiler.disable
def chunk_dplr_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gk: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    r"""
    Main interface function for chunked delta-rule attention.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`
        v (torch.Tensor):
            values of shape `[B, T, H, V]`
        a (torch.Tensor):
            activations of shape `[B, T, H, K]`
        b (torch.Tensor):
            betas of shape `[B, T, H, K]`
        gk (torch.Tensor):
            decay term of shape `[B, T, H, K]`, given in log space
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it defaults to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
    """
    if q.dtype == torch.float32:
        warnings.warn(
            "ChunkDPLRDeltaRuleFunction does not support float32 on some platforms. "
            "Please use bfloat16/float16; float32 inputs must be converted by the caller.",
            category=RuntimeWarning,
            stacklevel=2,
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    scale = k.shape[-1] ** -0.5 if scale is None else scale
    o, final_state = ChunkDPLRDeltaRuleFunction.apply(
        q,
        k,
        v,
        a,
        b,
        gk,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
    )
    return o, final_state


def chunk_rwkv7(
    r: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    w: Optional[torch.Tensor] = None,
    log_w: Optional[torch.Tensor] = None,
    scale: float = 1.0,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
):
    """
    Interface function for RWKV-7 attention.
    """
    if w is not None:
        # RWKV-7 parameterizes the per-channel decay as exp(-exp(w)),
        # so the log-space decay expected by the DPLR kernel is -exp(w).
        log_w = -torch.exp(w)
    else:
        assert log_w is not None, "Either w or log_w must be provided!"

    return chunk_dplr_delta_rule(
        q=r,
        k=k,
        v=v,
        a=a,
        b=b,
        gk=log_w,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
    )


def transpose_head(x, head_first):
    # Move head-first inputs ([B, H, T, D]) to time-first ([B, T, H, D]),
    # then cast to bfloat16 and make the tensor contiguous for the kernels.
    if head_first:
        x = torch.permute(x, dims=(0, 2, 1, 3))
    return cast(x, torch.bfloat16).contiguous()


def generalized_delta_rule(
    r: torch.Tensor,
    w: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = True,
    head_first: bool = False,
):
    """
    Dispatching wrapper around the RWKV-7 kernels.

    Inputs may be head-first (`[B, H, T, D]`) or time-first (`[B, T, H, D]`);
    they are cast to bfloat16 and made contiguous. On CUDA devices the call is
    routed to the chunked Triton path via `chunk_rwkv7`; otherwise it falls
    back to the native Keras implementation.
    """
    dtype = r.dtype
    r = transpose_head(r, head_first)
    k = transpose_head(k, head_first)
    v = transpose_head(v, head_first)
    a = transpose_head(a, head_first)
    b = transpose_head(b, head_first)
    w = transpose_head(w, head_first)
    if w.device.type == "cuda":
        out, state = chunk_rwkv7(
            r=r,
            k=k,
            v=v,
            a=a,
            b=b,
            w=w,
            initial_state=initial_state,
            output_final_state=output_final_state,
        )
    else:
        # Fall back to the framework-native implementation on non-CUDA devices.
        from .native_keras_op import generalized_delta_rule as native_generalized_delta_rule

        out, state = native_generalized_delta_rule(
            r=r,
            k=k,
            v=v,
            a=a,
            b=b,
            w=w,
            initial_state=initial_state,
            output_final_state=output_final_state,
        )
    out = transpose_head(out, head_first)
    if output_final_state:
        return out, cast(state, dtype)
    return out
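
For orientation, here is a minimal usage sketch of generalized_delta_rule as defined in the file above. It is not taken from the package documentation: the shapes follow the docstrings in this file ([B, T, H, K] for r/w/k/a/b, [B, T, H, V] for v, [B, H, K, V] for the recurrent state), a CUDA device with Triton is assumed so the chunk_rwkv7 path is taken, and the random inputs, the float32 initial state, and the helper rand() are illustrative assumptions only; real RWKV-7 models derive a and b from normalized projections.

import torch

from rwkv_ops.rwkv7_kernel.torch_op import generalized_delta_rule

B, T, H, K, V = 1, 64, 4, 64, 64          # batch, time, heads, key dim, value dim
device, dtype = "cuda", torch.bfloat16     # assumes a CUDA device with Triton available

def rand(*shape):
    # Illustrative random inputs; real models produce these from projections.
    return torch.randn(*shape, device=device, dtype=dtype)

r, w, k, a, b = rand(B, T, H, K), rand(B, T, H, K), rand(B, T, H, K), rand(B, T, H, K), rand(B, T, H, K)
v = rand(B, T, H, V)
state0 = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)  # dtype assumed

# On CUDA this dispatches to the Triton-backed chunk_rwkv7 path above; on other
# devices it falls back to the implementation in native_keras_op.py.
out, state = generalized_delta_rule(
    r, w, k, v, a, b,
    initial_state=state0,
    output_final_state=True,
    head_first=False,
)
print(out.shape, state.shape)  # [B, T, H, V] and [B, H, K, V]

Internally, chunk_rwkv7 converts w to a log-space decay via log_w = -torch.exp(w), an effective per-channel decay factor of exp(-exp(w)) in (0, 1), before calling chunk_dplr_delta_rule.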
rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py
@@ -0,0 +1,34 @@
# ---------- chunk_A ----------
from .chunk_A_bwd import (
    chunk_dplr_bwd_kernel_intra,
    chunk_dplr_bwd_dgk_kernel,
)
from .chunk_A_fwd import chunk_dplr_fwd_A_kernel_intra_sub_intra

# ---------- chunk_h ----------
from .chunk_h_bwd import chunk_dplr_bwd_kernel_dhu
from .chunk_h_fwd import chunk_dplr_fwd_kernel_h

# ---------- chunk_o ----------
from .chunk_o_bwd import (
    chunk_dplr_bwd_kernel_dAu,
    chunk_dplr_bwd_o_kernel,
    chunk_dplr_bwd_kernel_dv,
)
from .chunk_o_fwd import chunk_dplr_fwd_kernel_o

# ---------- cumsum ----------
from .cumsum import chunk_rwkv6_fwd_cumsum_kernel

# ---------- wy_fast ----------
from .wy_fast_bwd import (
    prepare_wy_repr_bwd_kernel,
)
from .wy_fast_fwd import (
    prepare_wy_repr_fwd_kernel_chunk32,
    prepare_wy_repr_fwd_kernel_chunk64,
    wu_fwd_kernel,
)

# ---------- utils ----------
from .utils import *
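
The re-exports above expose the raw Triton kernels directly. A brief, hedged import sketch, assuming the wheel is installed under the rwkv_ops package exactly as laid out in the file listing at the top of this diff (the import path is inferred from that listing, not from documentation):

from rwkv_ops.rwkv7_kernel.triton_kernel import (
    chunk_rwkv6_fwd_cumsum_kernel,  # cumulative-sum kernel re-exported above
    chunk_dplr_fwd_kernel_h,        # chunked hidden-state forward kernel
)

# These are low-level Triton kernels; they are presumably driven by the wrapper
# modules under rwkv7_kernel/ (e.g. torch_kernel/ and jax_kernel/, which carry
# matching module names) rather than launched by hand.
print(chunk_rwkv6_fwd_cumsum_kernel, chunk_dplr_fwd_kernel_h)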