rwkv-ops 0.1.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of rwkv-ops might be problematic.

Files changed (43)
  1. rwkv_ops/__init__.py +26 -0
  2. rwkv_ops/rwkv7_kernel/__init__.py +153 -0
  3. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +221 -0
  4. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  5. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  6. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  7. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  8. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  9. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  10. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  11. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  12. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  13. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  14. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  15. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  16. rwkv_ops/rwkv7_kernel/native_keras_op.py +95 -0
  17. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  18. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  19. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  20. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  21. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  22. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  23. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  24. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  25. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  26. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  27. rwkv_ops/rwkv7_kernel/torch_op.py +523 -0
  28. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  29. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  30. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  31. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  32. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  33. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  34. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  35. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  36. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  37. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  38. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  39. rwkv_ops-0.1.0.dist-info/LICENSE.txt +201 -0
  40. rwkv_ops-0.1.0.dist-info/METADATA +118 -0
  41. rwkv_ops-0.1.0.dist-info/RECORD +43 -0
  42. rwkv_ops-0.1.0.dist-info/WHEEL +5 -0
  43. rwkv_ops-0.1.0.dist-info/top_level.txt +1 -0
rwkv_ops/rwkv7_kernel/torch_op.py
@@ -0,0 +1,523 @@
+ # -*- coding: utf-8 -*-
+ """
+ This file implements the forward and backward pass of a chunked delta rule attention mechanism,
+ optimized with Triton kernels for GPU acceleration. It includes functions for forward propagation,
+ backward gradient computation, and integration with PyTorch's autograd system.
+ """
+
+ import warnings
+ from typing import Optional
+
+ import torch
+ import triton
+
+ # Import kernel implementation modules
+ from .torch_kernel.chunk_A_bwd import chunk_dplr_bwd_dqk_intra
+ from .torch_kernel.chunk_A_fwd import chunk_dplr_fwd_intra
+ from .torch_kernel.chunk_h_bwd import chunk_dplr_bwd_dhu
+ from .torch_kernel.chunk_h_fwd import chunk_dplr_fwd_h
+ from .torch_kernel.chunk_o_bwd import (
+     chunk_dplr_bwd_dAu,
+     chunk_dplr_bwd_dv,
+     chunk_dplr_bwd_o,
+ )
+ from .torch_kernel.chunk_o_fwd import chunk_dplr_fwd_o
+ from .torch_kernel.wy_fast_bwd import chunk_dplr_bwd_wy
+ from .torch_kernel.wy_fast_fwd import prepare_wy_repr_fwd
+ from .torch_kernel.cumsum import chunk_rwkv6_fwd_cumsum
+ from .get_torch_devices_info import (
+     autocast_custom_bwd,
+     autocast_custom_fwd,
+     input_guard,
+ )
+
+
+ def cast(x, dtype):
+     """
+     Cast tensor x to the specified dtype if it is not already in that format.
+     """
+     if x is None or x.dtype == dtype:
+         return x
+     return x.to(dtype)
+
+
+ def chunk_dplr_fwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     gk: torch.Tensor,
+     scale: float = 1,
+     initial_state: torch.Tensor = None,
+     output_final_state: bool = True,
+     chunk_size: int = 16,
+ ):
+     """
+     Forward pass of chunked delta rule attention.
+
+     Args:
+         q (torch.Tensor): Queries tensor [B, T, H, K]
+         k (torch.Tensor): Keys tensor [B, T, H, K]
+         v (torch.Tensor): Values tensor [B, T, H, V]
+         a (torch.Tensor): Activations tensor [B, T, H, K]
+         b (torch.Tensor): Betas tensor [B, T, H, K]
+         gk (torch.Tensor): Log decay tensor [B, T, H, K]
+         scale (float): Scale factor for attention scores
+         initial_state (Optional[torch.Tensor]): Initial state for recurrent processing
+         output_final_state (bool): Whether to return the final state
+         chunk_size (int): Chunk size for processing
+
+     Returns:
+         o (torch.Tensor): Output tensor [B, T, H, V]
+         final_state (Optional[torch.Tensor]): Final state if requested
+     """
+     T = q.shape[1]
+     BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
+     gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
+
+     A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         scale=scale,
+         chunk_size=BT,
+     )
+
+     del ge
+
+     # Compute WY representation
+     w, u, _ = prepare_wy_repr_fwd(ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT)
+
+     del A_ab, A_ak
+
+     h, v_new, final_state = chunk_dplr_fwd_h(
+         kg=kg,
+         bg=bg,
+         v=v,
+         w=w,
+         u=u,
+         gk=gi,
+         initial_state=initial_state,
+         output_final_state=output_final_state,
+         chunk_size=BT,
+     )
+
+     del u, kg, bg, gi
+
+     o = chunk_dplr_fwd_o(
+         qg=qg, v=v, v_new=v_new, A_qk=A_qk, A_qb=A_qb, h=h, chunk_size=BT
+     )
+     del v_new, h, A_qk, A_qb
+
+     return o, final_state
+
+
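The chunk length `BT` chosen at the top of `chunk_dplr_fwd` equals the requested `chunk_size` except for very short sequences, where it shrinks toward the next power of two of `T` but never below 16. A minimal standalone sketch of that selection rule, reimplementing `next_power_of_2` in plain Python so it runs without Triton (the helper names below are ours, not part of the package):

def next_power_of_2(n: int) -> int:
    # Plain-Python stand-in for triton.next_power_of_2 (demo only).
    return 1 << max(n - 1, 0).bit_length()

def select_chunk_length(T: int, chunk_size: int = 16) -> int:
    # Mirrors: BT = min(chunk_size, max(triton.next_power_of_2(T), 16))
    return min(chunk_size, max(next_power_of_2(T), 16))

for T in (3, 20, 1024):
    print(T, select_chunk_length(T, chunk_size=16), select_chunk_length(T, chunk_size=64))
    # T=3    -> 16, 16
    # T=20   -> 16, 32
    # T=1024 -> 16, 64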
+ def chunk_dplr_bwd(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     gk: torch.Tensor,
+     initial_state: torch.Tensor,
+     scale,
+     do,
+     dht,
+     BT: int = 16,
+ ):
+     """
+     Backward pass of chunked delta rule attention.
+
+     Args:
+         q (torch.Tensor): Queries tensor [B, T, H, K]
+         k (torch.Tensor): Keys tensor [B, T, H, K]
+         v (torch.Tensor): Values tensor [B, T, H, V]
+         a (torch.Tensor): Activations tensor [B, T, H, K]
+         b (torch.Tensor): Betas tensor [B, T, H, K]
+         gk (torch.Tensor): Log decay tensor [B, T, H, K]
+         initial_state (torch.Tensor): Initial state for recurrent processing
+         scale (float): Scale factor for attention scores
+         do (torch.Tensor): Gradient of outputs
+         dht (torch.Tensor): Gradient of final hidden state
+         BT (int): Chunk size for processing
+
+     Returns:
+         dq (torch.Tensor): Gradient of queries
+         dk (torch.Tensor): Gradient of keys
+         dv (torch.Tensor): Gradient of values
+         da (torch.Tensor): Gradient of activations
+         db (torch.Tensor): Gradient of betas
+         dgk (torch.Tensor): Gradient of log decays
+         dh0 (torch.Tensor): Gradient of initial state
+     """
+     gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT)
+     A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_dplr_fwd_intra(
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         scale=scale,
+         chunk_size=BT,
+     )
+     w, u, A_ab_inv = prepare_wy_repr_fwd(
+         ag=ag, A_ab=A_ab, A_ak=A_ak, v=v, chunk_size=BT
+     )
+     del A_ab
+     h, v_new, _ = chunk_dplr_fwd_h(
+         kg=kg, bg=bg, v=v, w=w, u=u, gk=gi, initial_state=initial_state, chunk_size=BT
+     )
+     del u
+
+     dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu(
+         v=v, v_new=v_new, do=do, A_qb=A_qb, scale=scale, chunk_size=BT
+     )
+
+     dh, dh0, dv_new = chunk_dplr_bwd_dhu(
+         qg=qg,
+         bg=bg,
+         w=w,
+         gk=gi,
+         h0=initial_state,
+         dht=dht,
+         do=do,
+         dv=dv_new_intra,
+         chunk_size=BT,
+     )
+
+     dv = chunk_dplr_bwd_dv(A_qk=A_qk, kg=kg, do=do, dh=dh, chunk_size=BT)
+     del A_qk
+
+     dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o(
+         k=kg,
+         b=bg,
+         v=v,
+         v_new=v_new,
+         do=do,
+         h=h,
+         dh=dh,
+         dv=dv,
+         w=w,
+         gk=gi,
+         chunk_size=BT,
+         scale=scale,
+     )
+     del v_new
+
+     dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy(
+         A_ab_inv=A_ab_inv,
+         A_ak=A_ak,
+         v=v,
+         ag=ag,
+         dw=dw,
+         du=dv_new,
+         dv0=dv,
+         chunk_size=BT,
+     )
+     del A_ak
+
+     dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra(
+         q=q,
+         k=k,
+         a=a,
+         b=b,
+         gi=gi,
+         ge=ge,
+         dAqk=dA_qk,
+         dAqb=dA_qb,
+         dAak=dA_ak,
+         dAab=dA_ab,
+         dgk_last=dgk_last,
+         dqg=dqg,
+         dkg=dkg,
+         dag=dag,
+         dbg=dbg,
+         chunk_size=BT,
+         scale=scale,
+     )
+
+     # One entry per forward input of ChunkDPLRDeltaRuleFunction;
+     # None is returned for scale, output_final_state and cu_seqlens.
+     return (
+         dq.to(q),
+         dk.to(k),
+         dv.to(v),
+         da.to(a),
+         db.to(b),
+         dgk.to(gk),
+         None,
+         dh0,
+         None,
+         None,
+     )
+
+
+ class ChunkDPLRDeltaRuleFunction(torch.autograd.Function):
+     @staticmethod
+     @input_guard
+     @autocast_custom_fwd
+     def forward(
+         ctx,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         a: torch.Tensor,
+         b: torch.Tensor,
+         gk: torch.Tensor,
+         scale: float = 1,
+         initial_state: torch.Tensor = None,
+         output_final_state: bool = True,
+         cu_seqlens: Optional[torch.LongTensor] = None,
+     ):
+         """
+         Forward function with autograd support.
+         """
+         chunk_size = 16
+         o, final_state = chunk_dplr_fwd(
+             q=q,
+             k=k,
+             v=v,
+             a=a,
+             b=b,
+             gk=gk,
+             scale=scale,
+             initial_state=initial_state,
+             output_final_state=output_final_state,
+             chunk_size=chunk_size,
+         )
+         ctx.save_for_backward(q, k, v, a, b, gk, initial_state)
+         ctx.cu_seqlens = cu_seqlens
+         ctx.scale = scale
+         ctx.chunk_size = chunk_size
+         return o.to(q.dtype), final_state
+
+     @staticmethod
+     @input_guard
+     @autocast_custom_bwd
+     def backward(ctx, do: torch.Tensor, dht: torch.Tensor):
+         """
+         Backward function with autograd support.
+         """
+         q, k, v, a, b, gk, initial_state = ctx.saved_tensors
+         BT = ctx.chunk_size
+         cu_seqlens = ctx.cu_seqlens
+         scale = ctx.scale
+
+         return chunk_dplr_bwd(
+             q=q,
+             k=k,
+             v=v,
+             a=a,
+             b=b,
+             gk=gk,
+             scale=scale,
+             initial_state=initial_state,
+             do=do,
+             dht=dht,
+             BT=BT,
+         )
+
+
+ @torch.compiler.disable
+ def chunk_dplr_delta_rule(
+     q: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     gk: torch.Tensor,
+     scale: Optional[float] = None,
+     initial_state: Optional[torch.Tensor] = None,
+     output_final_state: bool = False,
+     cu_seqlens: Optional[torch.LongTensor] = None,
+ ):
+     r"""
+     Main interface function for chunked delta rule attention.
+
+     Args:
+         q (torch.Tensor):
+             queries of shape `[B, T, H, K]`
+         k (torch.Tensor):
+             keys of shape `[B, T, H, K]`
+         v (torch.Tensor):
+             values of shape `[B, T, H, V]`
+         a (torch.Tensor):
+             activations of shape `[B, T, H, K]`
+         b (torch.Tensor):
+             betas of shape `[B, T, H, K]`
+         gk (torch.Tensor):
+             decay term of shape `[B, T, H, K]`, given in log space
+         scale (Optional[float]):
+             Scale factor for the attention scores.
+             If not provided, it defaults to `1 / sqrt(K)`. Default: `None`.
+         initial_state (Optional[torch.Tensor]):
+             Initial state of shape `[N, H, K, V]` for `N` input sequences.
+             For equal-length input sequences, `N` equals the batch size `B`.
+             Default: `None`.
+         output_final_state (Optional[bool]):
+             Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+         cu_seqlens (torch.LongTensor):
+             Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+             consistent with the FlashAttention API.
+
+     Returns:
+         o (torch.Tensor):
+             Outputs of shape `[B, T, H, V]`.
+         final_state (torch.Tensor):
+             Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+     """
+     if q.dtype == torch.float32:
+         warnings.warn(
+             """ChunkDeltaRuleFunction does not support float32 on some platforms. Please use bfloat16/float16.
+             If you want to use float32, please solve the issue by yourself.""",
+             category=RuntimeWarning,
+             stacklevel=2,
+         )
+     if cu_seqlens is not None:
+         if q.shape[0] != 1:
+             raise ValueError(
+                 f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
+                 f"Please flatten variable-length inputs before processing."
+             )
+         if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+             raise ValueError(
+                 f"The number of initial states is expected to be equal to the number of input sequences, "
+                 f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+             )
+     scale = k.shape[-1] ** -0.5 if scale is None else scale
+     o, final_state = ChunkDPLRDeltaRuleFunction.apply(
+         q,
+         k,
+         v,
+         a,
+         b,
+         gk,
+         scale,
+         initial_state,
+         output_final_state,
+         cu_seqlens,
+     )
+     return o, final_state
+
+
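As a usage illustration for the interface above, the sketch below builds tensors with the documented shapes and calls `chunk_dplr_delta_rule`. It is only a sketch: it assumes a CUDA device with Triton installed, uses bfloat16 as the float32 warning recommends, and the particular construction of `a`, `b`, and `gk` is illustrative rather than anything the package prescribes.

import torch
import torch.nn.functional as F
from rwkv_ops.rwkv7_kernel.torch_op import chunk_dplr_delta_rule

B, T, H, K, V = 1, 64, 4, 64, 64
device, dtype = "cuda", torch.bfloat16

q = torch.randn(B, T, H, K, device=device, dtype=dtype)
k = torch.randn(B, T, H, K, device=device, dtype=dtype)
v = torch.randn(B, T, H, V, device=device, dtype=dtype)
# Illustrative rank-one update directions for the DPLR term.
a = F.normalize(torch.randn(B, T, H, K, device=device, dtype=dtype), dim=-1)
b = -a
# Decay term is expected in log space, i.e. non-positive values.
gk = -torch.rand(B, T, H, K, device=device, dtype=dtype)

o, final_state = chunk_dplr_delta_rule(
    q, k, v, a, b, gk,
    initial_state=None,
    output_final_state=True,
)
print(o.shape)            # torch.Size([1, 64, 4, 64])  -> [B, T, H, V]
print(final_state.shape)  # torch.Size([1, 4, 64, 64])  -> [N, H, K, V]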
+ def chunk_rwkv7(
+     r: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     w: torch.Tensor = None,
+     log_w: torch.Tensor = None,
+     scale: float = 1.0,
+     initial_state: torch.Tensor = None,
+     output_final_state: bool = True,
+ ):
+     """
+     Interface function for RWKV-7 attention.
+     """
+     if w is not None:
+         log_w = -torch.exp(w)
+     else:
+         assert log_w is not None, "Either w or log_w must be provided!"
+
+     return chunk_dplr_delta_rule(
+         q=r,
+         k=k,
+         v=v,
+         a=a,
+         b=b,
+         gk=log_w,
+         scale=scale,
+         initial_state=initial_state,
+         output_final_state=output_final_state,
+     )
+
+
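A small CPU-only sketch of the decay parameterization used by `chunk_rwkv7` above: when `w` is supplied, the kernel receives `log_w = -exp(w)`, so the effective per-step decay factor is `exp(log_w) = exp(-exp(w))`, which always lies strictly between 0 and 1.

import torch

w = torch.linspace(-3.0, 3.0, steps=7)
log_w = -torch.exp(w)     # what chunk_rwkv7 passes on as gk
decay = torch.exp(log_w)  # effective multiplicative decay per step
print(decay)              # all values in (0, 1); larger w means faster decay
assert torch.all((decay > 0) & (decay < 1))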
+ def transpose_head(x, head_first):
+     """
+     Transpose between head-first and time-first formats and cast to bfloat16.
+     """
+     if head_first:
+         x = torch.permute(x, dims=(0, 2, 1, 3))
+     out = cast(x, torch.bfloat16).contiguous()
+     return out
+
+
+ def generalized_delta_rule(
+     r: torch.Tensor,
+     w: torch.Tensor,
+     k: torch.Tensor,
+     v: torch.Tensor,
+     a: torch.Tensor,
+     b: torch.Tensor,
+     initial_state: torch.Tensor = None,
+     output_final_state: bool = True,
+     head_first: bool = False,
+ ):
+     """
+     Generalized delta rule attention interface.
+     """
+     dtype = r.dtype
+     r = transpose_head(r, head_first)
+     k = transpose_head(k, head_first)
+     v = transpose_head(v, head_first)
+     a = transpose_head(a, head_first)
+     b = transpose_head(b, head_first)
+     w = transpose_head(w, head_first)
+     if w.device.type == "cuda":
+         out, state = chunk_rwkv7(
+             r=r,
+             k=k,
+             v=v,
+             a=a,
+             b=b,
+             w=w,
+             initial_state=initial_state,
+             output_final_state=output_final_state,
+         )
+     else:
+         from .native_keras_op import generalized_delta_rule
+
+         out, state = generalized_delta_rule(
+             r=r,
+             k=k,
+             v=v,
+             a=a,
+             b=b,
+             w=w,
+             initial_state=initial_state,
+             output_final_state=output_final_state,
+         )
+     out = transpose_head(out, head_first)
+     if output_final_state:
+         return out, cast(state, dtype)
+     else:
+         return out
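Finally, a hedged usage sketch for `generalized_delta_rule` with head-first inputs. It shows the dispatch visible in the code above: tensors are cast to bfloat16 and permuted by `transpose_head`, then routed to the Triton-backed `chunk_rwkv7` on CUDA or to the `native_keras_op` fallback elsewhere (the fallback module name suggests a Keras backend, and importing `torch_op` itself still requires `triton` to be installed). The tensor construction is illustrative only.

import torch
import torch.nn.functional as F
from rwkv_ops.rwkv7_kernel.torch_op import generalized_delta_rule

device = "cuda" if torch.cuda.is_available() else "cpu"
B, H, T, K = 1, 4, 32, 64
shape = (B, H, T, K)  # head-first layout

r = torch.randn(shape, device=device)
k = torch.randn(shape, device=device)
v = torch.randn(shape, device=device)
a = F.normalize(torch.randn(shape, device=device), dim=-1)
b = -a
w = torch.zeros(shape, device=device)  # inside chunk_rwkv7: log_w = -exp(0) = -1

out, state = generalized_delta_rule(
    r, w, k, v, a, b,
    initial_state=None,
    output_final_state=True,
    head_first=True,
)
print(out.shape)  # returned in head-first layout: [B, H, T, V]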
rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py
@@ -0,0 +1,34 @@
+ # ---------- chunk_A ----------
+ from .chunk_A_bwd import (
+     chunk_dplr_bwd_kernel_intra,
+     chunk_dplr_bwd_dgk_kernel,
+ )
+ from .chunk_A_fwd import chunk_dplr_fwd_A_kernel_intra_sub_intra
+
+ # ---------- chunk_h ----------
+ from .chunk_h_bwd import chunk_dplr_bwd_kernel_dhu
+ from .chunk_h_fwd import chunk_dplr_fwd_kernel_h
+
+ # ---------- chunk_o ----------
+ from .chunk_o_bwd import (
+     chunk_dplr_bwd_kernel_dAu,
+     chunk_dplr_bwd_o_kernel,
+     chunk_dplr_bwd_kernel_dv,
+ )
+ from .chunk_o_fwd import chunk_dplr_fwd_kernel_o
+
+ # ---------- cumsum ----------
+ from .cumsum import chunk_rwkv6_fwd_cumsum_kernel
+
+ # ---------- wy_fast ----------
+ from .wy_fast_bwd import (
+     prepare_wy_repr_bwd_kernel,
+ )
+ from .wy_fast_fwd import (
+     prepare_wy_repr_fwd_kernel_chunk32,
+     prepare_wy_repr_fwd_kernel_chunk64,
+     wu_fwd_kernel,
+ )
+
+ # ---------- utils ----------
+ from .utils import *