rwkv_ops-0.6.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. rwkv_ops/__init__.py +45 -0
  2. rwkv_ops/mhc_kernel/__init__.py +50 -0
  3. rwkv_ops/mhc_kernel/common_kernel/include/mhc_types.h +66 -0
  4. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_post_op.cuh +197 -0
  5. rwkv_ops/mhc_kernel/common_kernel/kernels/mhc_pre_op.cuh +212 -0
  6. rwkv_ops/mhc_kernel/common_kernel/kernels/rmsnorm.cuh +152 -0
  7. rwkv_ops/mhc_kernel/common_kernel/kernels/sinkhorn_knopp.cuh +158 -0
  8. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_aggregate.cuh +141 -0
  9. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_distribute.cuh +111 -0
  10. rwkv_ops/mhc_kernel/common_kernel/kernels/stream_mix.cuh +164 -0
  11. rwkv_ops/mhc_kernel/common_kernel/kernels/type_conversions.cuh +52 -0
  12. rwkv_ops/mhc_kernel/jax_kernel/CMakeLists.txt +47 -0
  13. rwkv_ops/mhc_kernel/jax_kernel/mhu_ffi.cu +652 -0
  14. rwkv_ops/mhc_kernel/jax_kernel/mhu_jax.py +939 -0
  15. rwkv_ops/mhc_kernel/native_keras_op.py +193 -0
  16. rwkv_ops/mhc_kernel/torch_kernel/mhc_cuda.cu +207 -0
  17. rwkv_ops/mhc_kernel/torch_kernel/mhc_op.cpp +296 -0
  18. rwkv_ops/mhc_kernel/torch_kernel/mhc_torch.py +306 -0
  19. rwkv_ops/rwkv6_kernel/__init__.py +120 -0
  20. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/gpu_ops.cpp +44 -0
  21. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernel_helpers.h +64 -0
  22. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/kernels.h +56 -0
  23. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/pybind11_kernel_helpers.h +41 -0
  24. rwkv_ops/rwkv6_kernel/jax_kernel_cuda/rwkv_kernels.cu +512 -0
  25. rwkv_ops/rwkv6_kernel/jax_kernel_hip/gpu_ops.cpp +44 -0
  26. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernel_helpers.h +64 -0
  27. rwkv_ops/rwkv6_kernel/jax_kernel_hip/kernels.h +56 -0
  28. rwkv_ops/rwkv6_kernel/jax_kernel_hip/pybind11_kernel_helpers.h +41 -0
  29. rwkv_ops/rwkv6_kernel/jax_kernel_hip/rwkv_kernels.hip +514 -0
  30. rwkv_ops/rwkv6_kernel/jax_rwkv_kernel.py +722 -0
  31. rwkv_ops/rwkv6_kernel/ops_rwkv_kernel.py +90 -0
  32. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_cuda.cu +397 -0
  33. rwkv_ops/rwkv6_kernel/torch_kernel/wkv6_op.cpp +93 -0
  34. rwkv_ops/rwkv6_kernel/torch_rwkv_kernel.py +305 -0
  35. rwkv_ops/rwkv7_kernel/__init__.py +113 -0
  36. rwkv_ops/rwkv7_kernel/get_jax_devices_info.py +220 -0
  37. rwkv_ops/rwkv7_kernel/get_torch_devices_info.py +250 -0
  38. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/CMakeLists.txt +42 -0
  39. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_ffi.cu +399 -0
  40. rwkv_ops/rwkv7_kernel/jax_cuda_kernel/wkv7_jax.py +311 -0
  41. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/CMakeLists.txt +42 -0
  42. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_ffi.cu +172 -0
  43. rwkv_ops/rwkv7_kernel/jax_cuda_kernel_single/wkv7_single_step_jax.py +190 -0
  44. rwkv_ops/rwkv7_kernel/jax_kernel/__init__.py +9 -0
  45. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_bwd.py +95 -0
  46. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_A_fwd.py +60 -0
  47. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_bwd.py +78 -0
  48. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_h_fwd.py +80 -0
  49. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_bwd.py +150 -0
  50. rwkv_ops/rwkv7_kernel/jax_kernel/chunk_o_fwd.py +45 -0
  51. rwkv_ops/rwkv7_kernel/jax_kernel/cumsum.py +34 -0
  52. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_bwd.py +61 -0
  53. rwkv_ops/rwkv7_kernel/jax_kernel/wy_fast_fwd.py +86 -0
  54. rwkv_ops/rwkv7_kernel/jax_op.py +382 -0
  55. rwkv_ops/rwkv7_kernel/mlx_op.py +118 -0
  56. rwkv_ops/rwkv7_kernel/native_keras_op.py +108 -0
  57. rwkv_ops/rwkv7_kernel/tf_eager_kernel.py +155 -0
  58. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_cuda.cu +235 -0
  59. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_op.cpp +63 -0
  60. rwkv_ops/rwkv7_kernel/torch_cuda_kernel/wkv7_torch.py +233 -0
  61. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_cuda.cu +101 -0
  62. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_op.cpp +56 -0
  63. rwkv_ops/rwkv7_kernel/torch_cuda_kernel_single/wkv7_single_step_torch.py +112 -0
  64. rwkv_ops/rwkv7_kernel/torch_kernel/__init__.py +13 -0
  65. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_bwd.py +96 -0
  66. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_A_fwd.py +64 -0
  67. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_bwd.py +74 -0
  68. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_h_fwd.py +75 -0
  69. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_bwd.py +148 -0
  70. rwkv_ops/rwkv7_kernel/torch_kernel/chunk_o_fwd.py +44 -0
  71. rwkv_ops/rwkv7_kernel/torch_kernel/cumsum.py +31 -0
  72. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_bwd.py +63 -0
  73. rwkv_ops/rwkv7_kernel/torch_kernel/wy_fast_fwd.py +79 -0
  74. rwkv_ops/rwkv7_kernel/torch_op.py +504 -0
  75. rwkv_ops/rwkv7_kernel/triton_kernel/__init__.py +34 -0
  76. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_bwd.py +328 -0
  77. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_A_fwd.py +186 -0
  78. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_bwd.py +157 -0
  79. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_h_fwd.py +160 -0
  80. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_bwd.py +382 -0
  81. rwkv_ops/rwkv7_kernel/triton_kernel/chunk_o_fwd.py +137 -0
  82. rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py +86 -0
  83. rwkv_ops/rwkv7_kernel/triton_kernel/utils.py +20 -0
  84. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py +193 -0
  85. rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py +326 -0
  86. rwkv_ops-0.6.1.dist-info/METADATA +495 -0
  87. rwkv_ops-0.6.1.dist-info/RECORD +89 -0
  88. rwkv_ops-0.6.1.dist-info/WHEEL +4 -0
  89. rwkv_ops-0.6.1.dist-info/licenses/LICENSE.txt +201 -0
rwkv_ops/rwkv7_kernel/triton_kernel/cumsum.py
@@ -0,0 +1,86 @@

import triton
import triton.language as tl

from ..triton_kernel.utils import use_cuda_graph


@triton.autotune(
    configs=[
        triton.Config({"BS": BS}, num_warps=num_warps, num_stages=num_stages)
        for BS in [16, 32, 64]
        for num_warps in [4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=["S", "BT"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def chunk_rwkv6_fwd_cumsum_kernel(
    s,
    T,
    oi,
    oe,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
):
    cu_seqlens = None
    chunk_indices = None
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if False:  # variable-length path is disabled here; Triton prunes this branch
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    m_i = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0).to(tl.float32)
    m_e = tl.where(o_i[:, None] > o_i[None, :], 1.0, 0.0).to(tl.float32)

    p_s = tl.make_block_ptr(
        s + (bos * H + i_h) * S,
        (T, S),
        (H * S, 1),
        (i_t * BT, i_s * BS),
        (BT, BS),
        (1, 0),
    )
    p_oi = tl.make_block_ptr(
        oi + (bos * H + i_h) * S,
        (T, S),
        (H * S, 1),
        (i_t * BT, i_s * BS),
        (BT, BS),
        (1, 0),
    )
    p_oe = tl.make_block_ptr(
        oe + (bos * H + i_h) * S,
        (T, S),
        (H * S, 1),
        (i_t * BT, i_s * BS),
        (BT, BS),
        (1, 0),
    )
    # [BT, BS]
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    b_oi = tl.dot(m_i, b_s)
    b_oe = tl.dot(m_e, b_s)
    tl.store(
        p_oi,
        b_oi.to(p_oi.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_oe,
        b_oe.to(p_oe.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
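For reference, the kernel above computes per-chunk inclusive and exclusive cumulative sums as matrix products with triangular masks. A minimal NumPy sketch of the same computation on one tile (illustrative only, not part of the package):

import numpy as np

BT, BS = 8, 4
s = np.random.randn(BT, BS).astype(np.float32)  # one [BT, BS] tile of the input

o = np.arange(BT)
m_i = (o[:, None] >= o[None, :]).astype(np.float32)  # inclusive lower-triangular mask
m_e = (o[:, None] > o[None, :]).astype(np.float32)   # exclusive (strict) mask

oi = m_i @ s  # inclusive cumsum along the chunk (time) axis
oe = m_e @ s  # exclusive cumsum: row t sums rows 0..t-1

assert np.allclose(oi, np.cumsum(s, axis=0), atol=1e-4)
assert np.allclose(oe, oi - s, atol=1e-4)

Expressing the cumsum as tl.dot against a mask lets the kernel use tensor cores instead of a sequential scan.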
rwkv_ops/rwkv7_kernel/triton_kernel/utils.py
@@ -0,0 +1,20 @@

import triton
import triton.language as tl

is_gather_supported = hasattr(triton.language, "gather")
if not is_gather_supported:

    @triton.jit
    def gather(src, index, axis, _builder=None):
        # Fallback for Triton versions that lack tl.gather.
        # It performs no actual gather; it exists only to satisfy the compiler.
        return src

else:
    gather = tl.gather
exp = tl.exp
import keras

if keras.backend.backend() == "jax":
    from ..get_jax_devices_info import *
else:
    from ..get_torch_devices_info import *
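The shim above lets kernels reference a single gather symbol across Triton versions; real use is gated behind a tl.constexpr flag so the no-op fallback is never the active path when tl.gather exists, as prepare_wy_repr_fwd_kernel_chunk64 later in this diff does. A minimal sketch of the pattern (the kernel name and shapes here are hypothetical, not part of the package):

import triton
import triton.language as tl

from rwkv_ops.rwkv7_kernel.triton_kernel.utils import gather, is_gather_supported

@triton.jit
def select_row_kernel(x, out, i, N: tl.constexpr,
                      GATHER_SUPPORTED: tl.constexpr = is_gather_supported):
    # Load an [N, N] tile and extract row i from it.
    b_x = tl.load(x + tl.arange(0, N)[:, None] * N + tl.arange(0, N)[None, :])
    if GATHER_SUPPORTED:
        idx = tl.full([1, N], i, dtype=tl.int16)
        b_r = tl.sum(gather(b_x, idx, axis=0), 0)        # real tl.gather
    else:
        mask = tl.arange(0, N) == i
        b_r = tl.sum(tl.where(mask[:, None], b_x, 0), 0)  # mask-and-reduce fallback
    tl.store(out + tl.arange(0, N), b_r)

Because GATHER_SUPPORTED is a constexpr, only one branch is compiled into any given kernel instance.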
rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_bwd.py
@@ -0,0 +1,193 @@

# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang


import triton
import triton.language as tl

from ..triton_kernel.utils import use_cuda_graph

triton_config = {}


@triton.autotune(
    configs=[
        triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=["BT", "BK", "BV"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def prepare_wy_repr_bwd_kernel(
    A_ab_inv,
    A_ak,
    ag,
    v,
    dw,
    du,
    dv0,
    T,
    dAak,
    dAab,
    dv,
    dag,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if False:  # variable-length path is disabled here; Triton prunes this branch
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    p_Aak_t = tl.make_block_ptr(
        A_ak + (bos * H + i_h) * BT,
        (BT, T),
        (1, H * BT),
        (0, i_t * BT),
        (BT, BT),
        (0, 1),
    )
    p_Aab_inv_t = tl.make_block_ptr(
        A_ab_inv + (bos * H + i_h) * BT,
        (BT, T),
        (1, H * BT),
        (0, i_t * BT),
        (BT, BT),
        (0, 1),
    )
    p_dAak = tl.make_block_ptr(
        dAak + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, 0),
        (BT, BT),
        (1, 0),
    )
    p_dAab = tl.make_block_ptr(
        dAab + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, 0),
        (BT, BT),
        (1, 0),
    )

    b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1))
    b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1))
    b_A_ak_t = tl.where(
        tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0
    )
    b_A_ab_inv_t = tl.where(
        tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0
    )
    b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty)
    b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32)

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_dv = tl.make_block_ptr(
            dv + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_dv0 = tl.make_block_ptr(
            dv0 + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_du = tl.make_block_ptr(
            du + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_du = tl.load(p_du, boundary_check=(0, 1))
        b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v))
        b_dv0 = tl.load(p_dv0, boundary_check=(0, 1))
        b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    m_i = tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :]
    b_dA_tmp = tl.where(m_i, b_dA_tmp, 0)
    b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp)
    b_dA_ak = tl.where(m_i, b_dA_ak, 0)
    tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1))
    b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t)

    for i_k in range(tl.cdiv(K, BK)):
        p_ag = tl.make_block_ptr(
            ag + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_dag = tl.make_block_ptr(
            dag + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_dw = tl.make_block_ptr(
            dw + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_ag = tl.load(p_ag, boundary_check=(0, 1))
        b_dw = tl.load(p_dw, boundary_check=(0, 1))
        b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag))
        b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw)
        tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1))

    # Given dL/dA^(-1), dL/dA follows from:
    #   dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T
    # The fwd pass uses forward substitution to compute (I - lower(A_ab))^(-1).
    # Denote A = I - lower(A_ab) and B = A^(-1); then in the backward pass:
    #   dL/dA = -B^T @ (dL/dB) @ B^T
    #   dL/dA_ab = lower(B^T @ (dL/dB) @ B^T)
    b_dA_ab_inv = tl.where(
        tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0
    )
    b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv)
    b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t)
    b_dA_ab_inv = tl.where(m_i, b_dA_ab_inv, 0)
    tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1))
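The identity in the kernel's closing comment can be checked numerically. For a scalar loss f(A) = sum(G * inv(A)) with B = A^(-1) and G = dL/dB, the gradient is dL/dA = -B^T @ G @ B^T, and the extra sign from A = I - lower(A_ab) leaves dL/dA_ab = lower(B^T @ G @ B^T). A NumPy check (illustrative only, not part of the package):

import numpy as np

rng = np.random.default_rng(0)
n = 8
L = np.tril(rng.standard_normal((n, n)), k=-1)   # plays the role of lower(A_ab)
A = np.eye(n) - L                                # the matrix the fwd pass inverts
B = np.linalg.inv(A)
G = rng.standard_normal((n, n))                  # stand-in for dL/dB

dA = -B.T @ G @ B.T                              # claimed dL/dA for f(A) = sum(G * inv(A))

# central-difference check on one entry
eps, (i, j) = 1e-5, (5, 2)
E = np.zeros_like(A)
E[i, j] = eps
num = (np.sum(G * np.linalg.inv(A + E)) - np.sum(G * np.linalg.inv(A - E))) / (2 * eps)
assert np.isclose(num, dA[i, j], rtol=1e-4)

# A = I - lower(A_ab) contributes one more sign flip, giving the kernel's
# dL/dA_ab = lower(B^T @ dL/dB @ B^T)
dA_ab = np.tril(B.T @ G @ B.T, k=-1)
assert np.isclose(dA_ab[i, j], -dA[i, j])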
rwkv_ops/rwkv7_kernel/triton_kernel/wy_fast_fwd.py
@@ -0,0 +1,326 @@

# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang


import triton
import triton.language as tl

from ..triton_kernel.utils import is_gather_supported, use_cuda_graph, gather


@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16]],
    key=["BT"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def prepare_wy_repr_fwd_kernel_chunk32(
    A_ab,
    T,
    A_ab_inv,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,  # placeholder, do not delete
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if False:  # variable-length path is disabled here; Triton prunes this branch
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_Aab = tl.make_block_ptr(
        A_ab + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, 0),
        (BT, BT),
        (1, 0),
    )
    p_Aab_inv = tl.make_block_ptr(
        A_ab_inv + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, 0),
        (BT, BT),
        (1, 0),
    )
    b_A_ab = tl.load(p_Aab, boundary_check=(0, 1))
    b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0)
    for i in range(1, BT):
        mask = tl.arange(0, BT) == i
        b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i)
        b_A_ab = tl.where(mask[:, None], b_a, b_A_ab)
    b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :]
    tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1))

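The loop above is forward substitution in registers: writing (I - L)^(-1) = I + N with N strictly lower triangular, N satisfies N = L + L @ N, so row i depends only on already-finished rows j < i. The same recurrence in NumPy (illustrative only, not part of the package):

import numpy as np

rng = np.random.default_rng(0)
BT = 8
L = np.tril(rng.standard_normal((BT, BT)), k=-1)  # strictly lower, like lower(A_ab)

N = L.copy()
for i in range(1, BT):
    # row i: N_i = L_i + sum_{j<i} L_ij * N_j; the column mask mirrors the kernel
    N[i] = N[i] + (N[i] @ N) * (np.arange(BT) < i)

assert np.allclose(np.eye(BT) + N, np.linalg.inv(np.eye(BT) - L))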

@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["BC"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def prepare_wy_repr_fwd_kernel_chunk64(
    A_ab,
    T,
    A_ab_inv,
    H: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    GATHER_SUPPORTED: tl.constexpr = is_gather_supported,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if False:  # variable-length path is disabled here; Triton prunes this branch
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    p_A1 = tl.make_block_ptr(
        A_ab + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, 0),
        (BC, BC),
        (1, 0),
    )
    p_A2 = tl.make_block_ptr(
        A_ab + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT + BC, BC),
        (BC, BC),
        (1, 0),
    )
    p_A3 = tl.make_block_ptr(
        A_ab + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT + BC, 0),
        (BC, BC),
        (1, 0),
    )
    p_A_inv1 = tl.make_block_ptr(
        A_ab_inv + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, 0),
        (BC, BC),
        (1, 0),
    )
    p_A_inv2 = tl.make_block_ptr(
        A_ab_inv + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT + BC, BC),
        (BC, BC),
        (1, 0),
    )
    p_A_inv3 = tl.make_block_ptr(
        A_ab_inv + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT + BC, 0),
        (BC, BC),
        (1, 0),
    )
    p_A_inv4 = tl.make_block_ptr(
        A_ab_inv + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, BC),
        (BC, BC),
        (1, 0),
    )

    b_A = tl.load(p_A1, boundary_check=(0, 1))
    b_A2 = tl.load(p_A2, boundary_check=(0, 1))
    b_A3 = tl.load(p_A3, boundary_check=(0, 1))
    b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0)
    b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0)

    for i in range(1, BC):
        if GATHER_SUPPORTED:
            row_idx = tl.full([1, BC], i, dtype=tl.int16)
            # [1, BC] -> [BC]
            b_a = tl.sum(gather(b_A, row_idx, axis=0), 0)
            b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0)
        else:
            mask = tl.arange(0, BC) == i
            b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
            b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        mask = tl.arange(0, BC) == i
        # b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0)
        # b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i)
        b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i)
        b_A = tl.where(mask[:, None], b_a, b_A)
        b_A2 = tl.where(mask[:, None], b_a2, b_A2)

    # blockwise computation of a lower triangular matrix's inverse,
    # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1]
    b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :]
    b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A)
    # tl.debug_barrier()
    tl.store(
        p_A_inv1,
        b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_A_inv2,
        b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_A_inv3,
        b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    # zero the upper-right block (causal mask)
    tl.store(
        p_A_inv4,
        tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty),
        boundary_check=(0, 1),
    )

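One detail worth noting in the kernel above: the comment's block-inverse formula carries a minus sign on the off-diagonal term, yet b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A) has none. The matrix actually inverted is I - lower(A_ab), whose lower-left block is the negated A_ab block, so the two signs cancel. A NumPy check (illustrative only, not part of the package):

import numpy as np

rng = np.random.default_rng(0)
BC = 4
A_ab = np.tril(rng.standard_normal((2 * BC, 2 * BC)), k=-1)
M = np.eye(2 * BC) - A_ab                      # the matrix the fwd pass inverts

inv1 = np.linalg.inv(M[:BC, :BC])              # A11^-1, the kernel's b_A
inv2 = np.linalg.inv(M[BC:, BC:])              # A22^-1, the kernel's b_A2
inv3 = inv2 @ A_ab[BC:, :BC] @ inv1            # the kernel's b_A3: minus signs cancel

M_inv = np.linalg.inv(M)
assert np.allclose(M_inv[:BC, :BC], inv1)
assert np.allclose(M_inv[BC:, BC:], inv2)
assert np.allclose(M_inv[BC:, :BC], inv3)
assert np.allclose(M_inv[:BC, BC:], 0)         # zero block, stored via p_A_inv4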

@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=["H", "K", "V", "BT", "BK", "BV"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def wu_fwd_kernel(
    ag,
    v,
    A_ab_inv,
    A_ak,
    T,
    w,
    u,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if False:  # variable-length path is disabled here; Triton prunes this branch
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_s = tl.arange(0, BT)

    p_A_ab_inv = tl.make_block_ptr(
        A_ab_inv + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, 0),
        (BT, BT),
        (1, 0),
    )
    p_A_ak = tl.make_block_ptr(
        A_ak + (bos * H + i_h) * BT,
        (T, BT),
        (H * BT, 1),
        (i_t * BT, 0),
        (BT, BT),
        (1, 0),
    )

    b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1))
    b_Aak = tl.load(p_A_ak, boundary_check=(0, 1))
    b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0)
    b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0)
    # let's use tf32 here
    b_Aak = tl.dot(b_Aab_inv, b_Aak)
    # (SY 01/04) should be bf16 or tf32? To verify.
    b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne")
    b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne")

    for i_k in range(tl.cdiv(K, BK)):
        p_ag = tl.make_block_ptr(
            ag + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_w = tl.make_block_ptr(
            w + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_ag = tl.load(p_ag, boundary_check=(0, 1))
        b_w = tl.dot(b_Aab_inv, b_ag)  # both bf16 or fp16
        tl.store(
            p_w,
            b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_u = tl.make_block_ptr(
            u + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_u = tl.dot(b_Aak, b_v)  # both bf16 or fp16
        tl.store(
            p_u,
            b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
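wu_fwd_kernel produces w = A_ab_inv @ ag and u = (A_ab_inv @ A_ak) @ v chunk by chunk. A hedged sketch of a host-side wrapper launching it from PyTorch; the wrapper name, the [B, T, H, K] / [B, T, H, V] tensor layout, and the tile choices are assumptions, not the package's API (its real wrappers live under rwkv_ops/rwkv7_kernel/):

import torch
import triton

from rwkv_ops.rwkv7_kernel.triton_kernel.wy_fast_fwd import wu_fwd_kernel

def wu_fwd(ag, v, A_ab_inv, A_ak, BT=64):
    # Assumed shapes: ag [B, T, H, K], v [B, T, H, V],
    # A_ab_inv and A_ak [B, T, H, BT] (chunk-local BT x BT blocks).
    B, T, H, K = ag.shape
    V = v.shape[-1]
    w = torch.empty_like(ag)
    u = torch.empty_like(v)
    BK = min(triton.next_power_of_2(K), 64)
    BV = min(triton.next_power_of_2(V), 64)
    grid = (triton.cdiv(T, BT), B * H)  # (i_t, i_bh), matching the kernel's program_ids
    wu_fwd_kernel[grid](
        ag, v, A_ab_inv, A_ak, T, w, u,
        H=H, K=K, V=V, BT=BT, BK=BK, BV=BV,
    )
    return w, u

The (H * K, 1) strides inside the kernel assume the contiguous [B, T, H, K] layout above, which is why no explicit stride arguments are passed.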