cuequivariance-ops-cu12 0.6.0__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of cuequivariance-ops-cu12 has been flagged as a potentially problematic release.

Files changed (37)
  1. cuequivariance_ops/VERSION +1 -0
  2. cuequivariance_ops/__init__.py +42 -0
  3. cuequivariance_ops/_version.py +20 -0
  4. cuequivariance_ops/common/common.hpp +98 -0
  5. cuequivariance_ops/common/nvtx.hpp +29 -0
  6. cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
  7. cuequivariance_ops/equivariance/dtypes.hh +65 -0
  8. cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
  9. cuequivariance_ops/equivariance/indexed_linear.hh +36 -0
  10. cuequivariance_ops/equivariance/run_fmha.h +192 -0
  11. cuequivariance_ops/equivariance/run_fmha_cudafree.h +77 -0
  12. cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
  13. cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
  14. cuequivariance_ops/lib/libcue_ops.so +0 -0
  15. cuequivariance_ops/sleep.hh +18 -0
  16. cuequivariance_ops/triton/__init__.py +66 -0
  17. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37192 -0
  18. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
  19. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
  20. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
  21. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
  22. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
  23. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
  24. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
  25. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
  26. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
  27. cuequivariance_ops/triton/cache_manager.py +259 -0
  28. cuequivariance_ops/triton/fused_layer_norm_triton.py +518 -0
  29. cuequivariance_ops/triton/gated_gemm_triton.py +380 -0
  30. cuequivariance_ops/triton/pair_bias.py +324 -0
  31. cuequivariance_ops/triton/tuning_decorator.py +177 -0
  32. cuequivariance_ops/triton/utils.py +28 -0
  33. cuequivariance_ops_cu12-0.6.0.dist-info/METADATA +182 -0
  34. cuequivariance_ops_cu12-0.6.0.dist-info/RECORD +37 -0
  35. cuequivariance_ops_cu12-0.6.0.dist-info/WHEEL +6 -0
  36. cuequivariance_ops_cu12-0.6.0.dist-info/licenses/LICENSE +142 -0
  37. cuequivariance_ops_cu12-0.6.0.dist-info/licenses/Third_party_attr.txt +24 -0
cuequivariance_ops/triton/gated_gemm_triton.py
@@ -0,0 +1,380 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+
+ import triton
+ import triton.language as tl
+
+ from cuequivariance_ops.triton.utils import cvt_tf32_rn
+
+
+ @triton.jit
+ def fused_sigmoid_gated_dual_gemm_forward_kernel(
+     # inputs
+     x1_ptr,
+     x2_ptr,
+     w1_ptr,
+     w2_ptr,
+     b1_ptr,
+     b2_ptr,
+     mask_ptr,
+     M,
+     N,
+     K,
+     # outputs
+     o_ptr,
+     TILE_M: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_K: tl.constexpr,
+     PRECISION: tl.constexpr,
+     APPLY_MASK: tl.constexpr,
+     TRANSPOSE_OUT: tl.constexpr,
+     TWO_INPUTS: tl.constexpr,
+     HAS_B1: tl.constexpr,
+     HAS_B2: tl.constexpr,
+ ):
+     # fully gated GEMM kernel with optional mask at the end
+     pid_m = tl.program_id(axis=0)
+     pid_n = tl.program_id(axis=1)
+
+     start_m = pid_m * TILE_M
+     start_n = pid_n * TILE_N
+
+     offs_xm = start_m + tl.arange(0, TILE_M)
+     offs_wn = start_n + tl.arange(0, TILE_N)
+     offs_k = tl.arange(0, TILE_K)
+
+     x1_ptrs = x1_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+     if TWO_INPUTS:
+         x2_ptrs = x2_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+
+     w_tile_offs = offs_wn[None, :] * K + offs_k[:, None]
+
+     acc_1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+     acc_2 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+
+     mask_m = offs_xm < M
+
+     if TWO_INPUTS:
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x1 = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x1, w1, acc_1)
+             elif PRECISION == 1:
+                 x1 = cvt_tf32_rn(x1)
+                 w1 = cvt_tf32_rn(w1)
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x2 = tl.load(x2_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w2_ptr.type.element_ty
+             )
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x2, w2, acc_2)
+             elif PRECISION == 1:
+                 x2 = cvt_tf32_rn(x2)
+                 w2 = cvt_tf32_rn(w2)
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x2_ptrs += TILE_K
+             w2_ptr += TILE_K
+
+     else:
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x, w1, acc_1)
+             elif PRECISION == 1:
+                 x = cvt_tf32_rn(x)
+                 w1 = cvt_tf32_rn(w1)
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x, w2, acc_2)
+             elif PRECISION == 1:
+                 x = cvt_tf32_rn(x)
+                 w2 = cvt_tf32_rn(w2)
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+             w2_ptr += TILE_K
+
+     offs_om = pid_m * TILE_M + tl.arange(0, TILE_M)
+     offs_on = pid_n * TILE_N + tl.arange(0, TILE_N)
+
+     if HAS_B1:
+         b1_ptrs = b1_ptr + offs_on
+         b1_tile = tl.load(b1_ptrs).to(tl.float32)
+         acc_1 += b1_tile
+
+     if HAS_B2:
+         b2_ptrs = b2_ptr + offs_on
+         b2_tile = tl.load(b2_ptrs).to(tl.float32)
+         acc_2 += b2_tile
+
+     acc_1 = 1.0 / (1.0 + tl.exp(-acc_1))
+     acc_gated = acc_1 * acc_2
+
+     if APPLY_MASK:
+         mask = tl.load(mask_ptr + offs_om, mask=mask_m, other=0.0).to(tl.float32)
+         acc_gated = acc_gated * mask[:, None]
+
+     if TRANSPOSE_OUT:
+         o_ptrs = o_ptr + offs_on[None, :] * M + offs_om[:, None]
+     else:
+         o_ptrs = o_ptr + offs_om[:, None] * N + offs_on[None, :]
+
+     o_mask = offs_om[:, None] < M
+     tl.store(o_ptrs, acc_gated, mask=o_mask)
+
+
+ @triton.jit
+ def fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel(
+     # inputs
+     grad_o_ptr,
+     x1_ptr,
+     x2_ptr,
+     w1_ptr,
+     w2_ptr,
+     b1_ptr,
+     b2_ptr,
+     mask_ptr,
+     M,
+     N,
+     K,
+     # outputs
+     grad_xw1_ptr,
+     grad_xw2_ptr,
+     grad_mask_ptr,
+     TILE_M: tl.constexpr,
+     TILE_N: tl.constexpr,
+     TILE_K: tl.constexpr,
+     PRECISION: tl.constexpr,
+     APPLY_MASK: tl.constexpr,
+     TRANSPOSE_OUT: tl.constexpr,
+     TWO_INPUTS: tl.constexpr,
+     HAS_B1: tl.constexpr,
+     HAS_B2: tl.constexpr,
+ ):
+     # fully gated GEMM kernel with optional mask at the end
+     pid_m = tl.program_id(axis=0)
+     pid_n = tl.program_id(axis=1)
+
+     start_m = pid_m * TILE_M
+     start_n = pid_n * TILE_N
+
+     offs_xm = start_m + tl.arange(0, TILE_M)
+     offs_wn = start_n + tl.arange(0, TILE_N)
+     offs_k = tl.arange(0, TILE_K)
+
+     x1_ptrs = x1_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+     if TWO_INPUTS:
+         x2_ptrs = x2_ptr + (offs_xm[:, None] * K + offs_k[None, :])
+     w_tile_offs = offs_wn[None, :] * K + offs_k[:, None]
+
+     acc_1 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+     acc_2 = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+
+     mask_m = offs_xm < M
+
+     if TWO_INPUTS:
+         # recompute acc1 and acc2
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x1 = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x1, w1, acc_1)
+             elif PRECISION == 1:
+                 x1 = cvt_tf32_rn(x1)
+                 w1 = cvt_tf32_rn(w1)
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x1, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x2 = tl.load(x2_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w2_ptr.type.element_ty
+             )
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x2, w2, acc_2)
+             elif PRECISION == 1:
+                 x2 = cvt_tf32_rn(x2)
+                 w2 = cvt_tf32_rn(w2)
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x2, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x2_ptrs += TILE_K
+             w2_ptr += TILE_K
+
+     else:
+         # recompute acc1 and acc2
+         for _ in range(0, tl.cdiv(K, TILE_K)):
+             x = tl.load(x1_ptrs, mask=mask_m[:, None], other=0.0).to(
+                 w1_ptr.type.element_ty
+             )
+
+             w1_ptrs = w1_ptr + w_tile_offs
+             w1 = tl.load(w1_ptrs)
+             if PRECISION == 0:
+                 acc_1 = tl.dot(x, w1, acc_1)
+             elif PRECISION == 1:
+                 x = cvt_tf32_rn(x)
+                 w1 = cvt_tf32_rn(w1)
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_1 = tl.dot(x, w1, acc_1, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             w2_ptrs = w2_ptr + w_tile_offs
+             w2 = tl.load(w2_ptrs)
+             if PRECISION == 0:
+                 acc_2 = tl.dot(x, w2, acc_2)
+             elif PRECISION == 1:
+                 x = cvt_tf32_rn(x)
+                 w2 = cvt_tf32_rn(w2)
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32")
+             elif PRECISION == 2:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="tf32x3")
+             elif PRECISION == 3:
+                 acc_2 = tl.dot(x, w2, acc_2, input_precision="ieee")
+             else:
+                 tl.static_assert(
+                     False,
+                     "PRECISION must be 0 (default), 1 (tf32), 2 (tf32x3) or 3 (ieee)",
+                 )
+
+             x1_ptrs += TILE_K
+             w1_ptr += TILE_K
+             w2_ptr += TILE_K
+
+     offs_om = pid_m * TILE_M + tl.arange(0, TILE_M)
+     offs_on = pid_n * TILE_N + tl.arange(0, TILE_N)
+
+     if HAS_B1:
+         b1_ptrs = b1_ptr + offs_on
+         b1_tile = tl.load(b1_ptrs).to(tl.float32)
+         acc_1 += b1_tile
+
+     if HAS_B2:
+         b2_ptrs = b2_ptr + offs_on
+         b2_tile = tl.load(b2_ptrs).to(tl.float32)
+         acc_2 += b2_tile
+
+     if TRANSPOSE_OUT:
+         grad_o_ptrs = grad_o_ptr + offs_on[None, :] * M + offs_om[:, None]
+     else:
+         grad_o_ptrs = grad_o_ptr + offs_om[:, None] * N + offs_on[None, :]
+
+     grad_o = tl.load(grad_o_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)
+
+     acc_sig = 1.0 / (1.0 + tl.exp(-acc_1))
+
+     if APPLY_MASK:
+         tmp = acc_sig * acc_2
+         grad_mask = grad_o * tmp
+         grad_mask = tl.sum(grad_mask, axis=1)
+         grad_mask_ptrs = grad_mask_ptr + pid_n * M + offs_om
+         tl.store(grad_mask_ptrs, grad_mask, mask=mask_m)
+
+         mask = tl.load(mask_ptr + offs_om, mask=mask_m, other=0.0).to(tl.float32)
+         grad_o = grad_o * mask[:, None]
+
+     grad_xw2 = grad_o * acc_sig
+     grad_xw2_ptrs = grad_xw2_ptr + offs_om[:, None] * N + offs_on[None, :]
+     tl.store(grad_xw2_ptrs, grad_xw2, mask=mask_m[:, None])
+
+     tmp = (1.0 - acc_sig) * acc_sig
+     grad_xw1 = grad_o * acc_2 * tmp
+     grad_xw1_ptrs = grad_xw1_ptr + offs_om[:, None] * N + offs_on[None, :]
+     tl.store(grad_xw1_ptrs, grad_xw1, mask=mask_m[:, None])
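
For reference, the forward kernel above computes a sigmoid-gated dual GEMM, out = sigmoid(x1 @ w1.T + b1) * (x2 @ w2.T + b2) (with x2 reused as x1 when TWO_INPUTS is off), followed by an optional per-row mask and an optional transposed store; the backward pre-GEMM kernel recomputes both accumulators and produces the gradients with respect to the two GEMM outputs and the mask. The following is a minimal PyTorch sketch of that math, not code shipped in the wheel; the function names, dense layouts (x: M x K, w: N x K, mask: M) and eager style are assumptions made here for illustration.

import torch

def fused_sigmoid_gated_dual_gemm_reference(x1, x2, w1, w2, b1=None, b2=None,
                                            mask=None, transpose_out=False):
    # Hypothetical reference mirroring the forward kernel's arguments.
    # x1, x2: (M, K); w1, w2: (N, K); b1, b2: (N,); mask: (M,)
    acc_1 = x1 @ w1.T
    acc_2 = (x2 if x2 is not None else x1) @ w2.T  # TWO_INPUTS=False reuses x1
    if b1 is not None:
        acc_1 = acc_1 + b1
    if b2 is not None:
        acc_2 = acc_2 + b2
    out = torch.sigmoid(acc_1) * acc_2             # gate the second GEMM by the sigmoid of the first
    if mask is not None:
        out = out * mask[:, None]                  # APPLY_MASK: per-row mask
    return out.T if transpose_out else out         # TRANSPOSE_OUT stores the result as (N, M)

def fused_sigmoid_gated_dual_gemm_pregemm_grads(grad_o, x1, x2, w1, w2,
                                                b1=None, b2=None, mask=None):
    # Gradients w.r.t. the two GEMM outputs, matching what the backward pre-GEMM kernel writes out.
    acc_1 = x1 @ w1.T + (b1 if b1 is not None else 0.0)
    acc_2 = (x2 if x2 is not None else x1) @ w2.T + (b2 if b2 is not None else 0.0)
    sig = torch.sigmoid(acc_1)
    grad_mask = None
    if mask is not None:
        grad_mask = (grad_o * sig * acc_2).sum(dim=1)  # the kernel stores per-N-tile partial sums instead
        grad_o = grad_o * mask[:, None]
    grad_xw2 = grad_o * sig                        # d out / d acc_2
    grad_xw1 = grad_o * acc_2 * sig * (1.0 - sig)  # d out / d acc_1 through the sigmoid
    return grad_xw1, grad_xw2, grad_mask

Downstream, grad_xw1 and grad_xw2 would presumably be contracted with the inputs and weights to obtain input and weight gradients; those GEMMs are not part of this "pregemm" kernel.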
cuequivariance_ops/triton/pair_bias.py
@@ -0,0 +1,324 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ #
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ # property and proprietary rights in and to this material, related
+ # documentation and any modifications thereto. Any use, reproduction,
+ # disclosure or distribution of this material and related documentation
+ # without an express license agreement from NVIDIA CORPORATION or
+ # its affiliates is strictly prohibited.
+
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def pair_bias_norm_linear_mask_forward_kernel(
+     z_ptr,
+     mask_ptr,
+     w_proj_z_ptr,
+     b_proj_z_ptr,
+     w_ln_ptr,
+     b_ln_ptr,
+     U,
+     V,
+     multiplicity,
+     out_mask_ptr,
+     z_norm_ptr,
+     mean_ptr,
+     rstd_ptr,
+     TILE_V: tl.constexpr,
+     TILE_K: tl.constexpr,
+     NUM_HEADS: tl.constexpr,
+     NUM_HEADS_PER_BLK: tl.constexpr,
+     DIM_Z: tl.constexpr,
+     INF: tl.constexpr,
+     EPS: tl.constexpr,
+     ELEMENTWISE_AFFINE: tl.constexpr,
+     IS_TRAINING: tl.constexpr,
+     HAS_BIAS: tl.constexpr,
+     MASK_WITH_MULTIPLICITY: tl.constexpr,
+ ):
+     # prepare single mask
+     # z: B x U x V x D -> z' B x H x U x V
+     # mask: B x V
+     # out_mask = z' + (1 - mask) * inf
+
+     pid_v = tl.program_id(0)
+     pid_u = tl.program_id(1)
+     head_batch_idx = tl.program_id(2)
+     NUM_HEAD_BLKS = tl.cdiv(NUM_HEADS, NUM_HEADS_PER_BLK)
+     batch_idx = head_batch_idx // NUM_HEAD_BLKS
+     head_idx = head_batch_idx % NUM_HEAD_BLKS
+
+     stride_vz = V * DIM_Z
+     stride_uv = U * V
+     stride_uvz = U * V * DIM_Z
+
+     offs_u = pid_u
+     offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+     offs_k = tl.arange(0, TILE_K)
+     offs_z = tl.arange(0, DIM_Z)
+     mask_v = offs_v < V
+     offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+     mask_head = offs_head < NUM_HEADS
+
+     z_ptrs = z_ptr + batch_idx * stride_uvz + offs_u * stride_vz
+     z_ptrs += offs_v[:, None] * DIM_Z + offs_z[None, :]
+
+     z_tile_full = tl.load(z_ptrs, mask=mask_v[:, None], other=0.0).to(tl.float32)
+
+     mean = tl.sum(z_tile_full, axis=1) / DIM_Z
+     rstd = z_tile_full - mean[:, None]
+     rstd = rstd * rstd
+     rstd = tl.sum(rstd, axis=1) / DIM_Z
+     rstd = tl.rsqrt(rstd + EPS)
+
+     if IS_TRAINING:
+         mean_ptrs = mean_ptr + batch_idx * stride_uv
+         mean_ptrs += offs_u * V + offs_v
+         tl.store(mean_ptrs, mean, mask=mask_v)
+
+         rstd_ptrs = rstd_ptr + batch_idx * stride_uv
+         rstd_ptrs += offs_u * V + offs_v
+         tl.store(rstd_ptrs, rstd, mask=mask_v)
+
+     z_ptrs = z_ptr + batch_idx * stride_uvz + offs_u * stride_vz
+     z_ptrs += offs_v[:, None] * DIM_Z + offs_k[None, :]
+     w_ln_ptrs = w_ln_ptr + offs_k
+     b_ln_ptrs = b_ln_ptr + offs_k
+     w_proj_ptrs = w_proj_z_ptr + (offs_head[None, :] * DIM_Z + offs_k[:, None])
+
+     if IS_TRAINING:
+         z_norm_ptrs = z_norm_ptr + batch_idx * stride_uvz + offs_u * stride_vz
+         z_norm_ptrs += offs_v[:, None] * DIM_Z + offs_k[None, :]
+
+     num_tiles_k = DIM_Z // TILE_K
+     acc = tl.zeros((TILE_V, NUM_HEADS_PER_BLK), dtype=tl.float32)
+
+     for _ in range(0, num_tiles_k):
+         z_tile = tl.load(z_ptrs, mask=mask_v[:, None], other=0.0).to(tl.float32)
+         z_tile = (z_tile - mean[:, None]) * rstd[:, None]
+
+         if ELEMENTWISE_AFFINE:
+             w_ln_tile = tl.load(w_ln_ptrs).to(tl.float32)
+             b_ln_tile = tl.load(b_ln_ptrs).to(tl.float32)
+             z_tile = z_tile * w_ln_tile + b_ln_tile
+
+         if IS_TRAINING:
+             tl.store(z_norm_ptrs, z_tile, mask=mask_v[:, None])
+
+         w_tile = tl.load(w_proj_ptrs, mask=mask_head[None, :], other=0.0).to(tl.float32)
+
+         acc = tl.dot(z_tile, w_tile, acc, input_precision="tf32x3")
+
+         z_ptrs += TILE_K
+         w_proj_ptrs += TILE_K
+         if ELEMENTWISE_AFFINE:
+             w_ln_ptrs += TILE_K
+             b_ln_ptrs += TILE_K
+         if IS_TRAINING:
+             z_norm_ptrs += TILE_K
+
+     if HAS_BIAS:
+         b_proj_ptrs = b_proj_z_ptr + offs_head
+         b_proj_tile = tl.load(b_proj_ptrs, mask=mask_head, other=0.0).to(tl.float32)
+         acc += b_proj_tile[None, :]
+
+     offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+     mask_v = offs_v < V
+     offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+     mask_head = offs_head < NUM_HEADS
+
+     out_mask_ptrs = out_mask_ptr + batch_idx * multiplicity * NUM_HEADS * stride_uv
+     out_mask_ptrs += offs_u * V
+     out_mask_ptrs += offs_head[None, :] * stride_uv + offs_v[:, None]
+     mask_o = mask_head[None, :] & mask_v[:, None]
+
+     if MASK_WITH_MULTIPLICITY:
+         mask_ptrs = mask_ptr + batch_idx * multiplicity * V + offs_v
+
+         for _ in range(multiplicity):
+             mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+             out_tile = acc + (1.0 - mask_tile[:, None]) * (-INF)
+             tl.store(out_mask_ptrs, out_tile, mask=mask_o)
+
+             out_mask_ptrs += NUM_HEADS * stride_uv
+             mask_ptrs += V
+
+     else:
+         mask_ptrs = mask_ptr + batch_idx * V + offs_v
+         mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+
+         for _ in range(multiplicity):
+             out_tile = acc + (1.0 - mask_tile[:, None]) * (-INF)
+             tl.store(out_mask_ptrs, out_tile, mask=mask_o)
+
+             out_mask_ptrs += NUM_HEADS * stride_uv
+             mask_ptrs += V
+
+
+ @triton.jit
+ def pair_bias_linear_mask_forward_kernel(
+     z_ptr,
+     mask_ptr,
+     w_proj_z_ptr,
+     b_proj_z_ptr,
+     U,
+     V,
+     multiplicity,
+     out_mask_ptr,
+     TILE_V: tl.constexpr,
+     TILE_K: tl.constexpr,
+     NUM_HEADS: tl.constexpr,
+     NUM_HEADS_PER_BLK: tl.constexpr,
+     DIM_Z: tl.constexpr,
+     INF: tl.constexpr,
+     HAS_BIAS: tl.constexpr,
+     MASK_WITH_MULTIPLICITY: tl.constexpr,
+ ):
+     # prepare single mask
+     # z: B x U x V x D -> z' B x H x U x V
+     # mask: B x V
+     # out_mask = z' + (1 - mask) * inf
+
+     pid_v = tl.program_id(0)
+     pid_u = tl.program_id(1)
+     head_batch_idx = tl.program_id(2)
+     NUM_HEAD_BLKS = tl.cdiv(NUM_HEADS, NUM_HEADS_PER_BLK)
+     batch_idx = head_batch_idx // NUM_HEAD_BLKS
+     head_idx = head_batch_idx % NUM_HEAD_BLKS
+
+     stride_vz = V * DIM_Z
+     stride_uv = U * V
+     stride_uvz = U * V * DIM_Z
+
+     offs_u = pid_u
+     offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+     offs_k = tl.arange(0, TILE_K)
+     mask_v = offs_v < V
+     offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+     mask_head = offs_head < NUM_HEADS
+
+     z_ptrs = z_ptr + batch_idx * stride_uvz + offs_u * stride_vz
+     z_ptrs += offs_v[:, None] * DIM_Z + offs_k[None, :]
+
+     w_ptrs = w_proj_z_ptr + (offs_head[None, :] * DIM_Z + offs_k[:, None])
+
+     acc = tl.zeros((TILE_V, NUM_HEADS_PER_BLK), dtype=tl.float32)
+
+     for _ in range(0, DIM_Z // TILE_K):
+         z_tile = tl.load(z_ptrs, mask=mask_v[:, None], other=0.0).to(tl.float32)
+         w_tile = tl.load(w_ptrs, mask=mask_head[None, :], other=0.0).to(tl.float32)
+
+         acc = tl.dot(z_tile, w_tile, acc, input_precision="tf32x3")
+
+         z_ptrs += TILE_K
+         w_ptrs += TILE_K
+
+     if HAS_BIAS:
+         b_proj_ptrs = b_proj_z_ptr + offs_head
+         b_proj_tile = tl.load(b_proj_ptrs, mask=mask_head, other=0.0)
+         acc += b_proj_tile[None, :]
+
+     offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+     mask_v = offs_v < V
+     offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+     mask_head = offs_head < NUM_HEADS
+
+     out_mask_ptrs = out_mask_ptr + batch_idx * multiplicity * NUM_HEADS * stride_uv
+     out_mask_ptrs += offs_u * V
+     out_mask_ptrs += offs_head[None, :] * stride_uv + offs_v[:, None]
+     mask_o = mask_head[None, :] & mask_v[:, None]
+
+     if MASK_WITH_MULTIPLICITY:
+         mask_ptrs = mask_ptr + batch_idx * multiplicity * V + offs_v
+
+         for _ in range(multiplicity):
+             mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+             out_tile = acc + (1.0 - mask_tile[:, None]) * (-INF)
+             tl.store(out_mask_ptrs, out_tile, mask=mask_o)
+
+             out_mask_ptrs += NUM_HEADS * stride_uv
+             mask_ptrs += V
+
+     else:
+         mask_ptrs = mask_ptr + batch_idx * V + offs_v
+         mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+
+         for _ in range(multiplicity):
+             out_tile = acc + (1.0 - mask_tile[:, None]) * (-INF)
+             tl.store(out_mask_ptrs, out_tile, mask=mask_o)
+
+             out_mask_ptrs += NUM_HEADS * stride_uv
+             mask_ptrs += V
+
+
+ @triton.jit
+ def pair_bias_mask_forward_kernel(
+     z_ptr,
+     mask_ptr,
+     U,
+     V,
+     multiplicity,
+     out_mask_ptr,
+     TILE_V: tl.constexpr,
+     NUM_HEADS: tl.constexpr,
+     NUM_HEADS_PER_BLK: tl.constexpr,
+     INF: tl.constexpr,
+     MASK_WITH_MULTIPLICITY: tl.constexpr,
+ ):
+     # prepare single mask
+     # z: B x U x V x H -> z' B x H x U x V
+     # mask: B x V
+     # out_mask = z' + (1 - mask) * inf
+
+     pid_v = tl.program_id(0)
+     pid_u = tl.program_id(1)
+     head_batch_idx = tl.program_id(2)
+     NUM_HEAD_BLKS = tl.cdiv(NUM_HEADS, NUM_HEADS_PER_BLK)
+     batch_idx = head_batch_idx // NUM_HEAD_BLKS
+     head_idx = head_batch_idx % NUM_HEAD_BLKS
+
+     stride_nh = V * NUM_HEADS
+     stride_uv = U * V
+     stride_uvh = U * V * NUM_HEADS
+
+     offs_u = pid_u
+     offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+     mask_v = offs_v < V
+     offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+     mask_head = offs_head < NUM_HEADS
+
+     z_ptrs = z_ptr + batch_idx * stride_uvh + offs_u * stride_nh
+     z_ptrs += offs_v[:, None] * NUM_HEADS + offs_head[None, :]
+
+     mask_zo = mask_v[:, None] & mask_head[None, :]
+
+     z_tile = tl.load(z_ptrs, mask=mask_zo, other=0.0).to(tl.float32)
+
+     out_mask_ptrs = out_mask_ptr + batch_idx * multiplicity * NUM_HEADS * stride_uv
+     out_mask_ptrs += offs_u * V
+     out_mask_ptrs += offs_head[None, :] * stride_uv + offs_v[:, None]
+
+     if MASK_WITH_MULTIPLICITY:
+         mask_ptrs = mask_ptr + batch_idx * multiplicity * V + offs_v
+
+         for _ in range(multiplicity):
+             mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+             out_tile = z_tile + (1.0 - mask_tile[:, None]) * (-INF)
+             tl.store(out_mask_ptrs, out_tile, mask=mask_zo)
+
+             out_mask_ptrs += NUM_HEADS * stride_uv
+             mask_ptrs += V
+
+     else:
+         mask_ptrs = mask_ptr + batch_idx * V + offs_v
+         mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+
+         for _ in range(multiplicity):
+             out_tile = z_tile + (1.0 - mask_tile[:, None]) * (-INF)
+             tl.store(out_mask_ptrs, out_tile, mask=mask_zo)
+
+             out_mask_ptrs += NUM_HEADS * stride_uv
+             mask_ptrs += V
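
The three kernels above are variants of the same pair-bias mask preparation: project the pair representation z to one bias value per attention head (optionally after a LayerNorm over its last dimension, or take it as an already per-head bias), transpose it to B x H x U x V, replicate it multiplicity times, and add (1 - mask) * (-INF) so that masked V positions are suppressed in a subsequent softmax. Below is a minimal PyTorch sketch of the full norm + linear + mask variant, assuming the layouts implied by the pointer arithmetic (z: B x U x V x D, w_proj_z: H x D, mask: B x V or B x multiplicity x V, output: B x multiplicity x H x U x V); the helper name and the eps/inf defaults are assumptions made here, not part of the package.

import torch
import torch.nn.functional as F

def pair_bias_norm_linear_mask_reference(z, mask, w_proj_z, b_proj_z=None,
                                         w_ln=None, b_ln=None,
                                         multiplicity=1, eps=1e-5, inf=1e9):
    # z: (B, U, V, D); mask: (B, V) or (B, multiplicity, V); w_proj_z: (H, D)
    B, U, V, D = z.shape
    H = w_proj_z.shape[0]
    # LayerNorm over the last dim; ELEMENTWISE_AFFINE corresponds to passing w_ln/b_ln
    z_norm = F.layer_norm(z, (D,), weight=w_ln, bias=b_ln, eps=eps)
    bias = z_norm @ w_proj_z.T                     # (B, U, V, H)
    if b_proj_z is not None:                       # HAS_BIAS
        bias = bias + b_proj_z
    bias = bias.permute(0, 3, 1, 2)                # (B, H, U, V)
    bias = bias.unsqueeze(1).expand(B, multiplicity, H, U, V)
    if mask.dim() == 2:                            # one mask shared across all multiplicity copies
        mask = mask[:, None, :].expand(B, multiplicity, V)
    # additive attention mask: masked-out V positions are pushed towards -inf
    return bias + (1.0 - mask)[:, :, None, None, :] * (-inf)

pair_bias_linear_mask_forward_kernel skips the LayerNorm, and pair_bias_mask_forward_kernel skips the projection as well (its z is already B x U x V x H); in training mode the first kernel additionally writes out mean, rstd and the normalized z for use in the backward pass.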