cuequivariance-ops-cu12 0.8.1__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuequivariance_ops/VERSION +1 -0
- cuequivariance_ops/__init__.py +42 -0
- cuequivariance_ops/_version.py +20 -0
- cuequivariance_ops/common/common.hpp +98 -0
- cuequivariance_ops/common/cudart.hpp +286 -0
- cuequivariance_ops/common/error.hpp +66 -0
- cuequivariance_ops/common/error_raft.hpp +323 -0
- cuequivariance_ops/common/nvtx.hpp +29 -0
- cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
- cuequivariance_ops/equivariance/dtypes.hh +65 -0
- cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
- cuequivariance_ops/equivariance/indexed_linear.hh +41 -0
- cuequivariance_ops/equivariance/run_fmha.h +192 -0
- cuequivariance_ops/equivariance/run_fmha_cudafree.h +176 -0
- cuequivariance_ops/equivariance/run_fmha_sm100.h +135 -0
- cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
- cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
- cuequivariance_ops/gpu_timing_kernels.hh +42 -0
- cuequivariance_ops/lib/libcue_ops.so +0 -0
- cuequivariance_ops/sleep.hh +40 -0
- cuequivariance_ops/triton/__init__.py +66 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37142 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.12.0.json +37132 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.12.0.json +55692 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
- cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
- cuequivariance_ops/triton/cache_manager.py +336 -0
- cuequivariance_ops/triton/fused_layer_norm_triton.py +546 -0
- cuequivariance_ops/triton/gated_gemm_triton.py +394 -0
- cuequivariance_ops/triton/pair_bias.py +365 -0
- cuequivariance_ops/triton/tuning_decorator.py +188 -0
- cuequivariance_ops/triton/utils.py +29 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/METADATA +182 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/RECORD +46 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/WHEEL +6 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/licenses/LICENSE +142 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/licenses/Third_party_attr.txt +24 -0
- cuequivariance_ops_cu12-0.8.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- cuequivariance_ops_cu12.libs/libnvfatbin-b51d3b3f.so.12.8.90 +0 -0

cuequivariance_ops/triton/pair_bias.py
@@ -0,0 +1,365 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def pair_bias_norm_linear_mask_forward_kernel(
+    z_ptr,
+    mask_ptr,
+    w_proj_z_ptr,
+    b_proj_z_ptr,
+    w_ln_ptr,
+    b_ln_ptr,
+    U,
+    V,
+    multiplicity,
+    out_mask_ptr,
+    z_norm_ptr,
+    z_proj_ptr,
+    mean_ptr,
+    rstd_ptr,
+    TILE_V: tl.constexpr,
+    TILE_K: tl.constexpr,
+    NUM_HEADS: tl.constexpr,
+    NUM_HEADS_PER_BLK: tl.constexpr,
+    DIM_Z: tl.constexpr,
+    INF: tl.constexpr,
+    EPS: tl.constexpr,
+    ELEMENTWISE_AFFINE: tl.constexpr,
+    IS_TRAINING: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    MASK_WITH_MULTIPLICITY: tl.constexpr,
+    CACHE_Z_PROJ: tl.constexpr = False,
+    NEEDS_INT64: tl.constexpr = True,
+):
+    # prepare single mask
+    # z: B x U x V x D -> z' B x H x U x V
+    # mask: B x V
+    # out_mask = z' + (1 - mask) * inf
+
+    pid_v = tl.program_id(0)
+    pid_u = tl.program_id(1)
+
+    if NEEDS_INT64:
+        pid_u = tl.cast(pid_u, tl.int64)
+        pid_v = tl.cast(pid_v, tl.int64)
+        U = tl.cast(U, tl.int64)
+        V = tl.cast(V, tl.int64)
+
+    head_batch_idx = tl.program_id(2)
+    NUM_HEAD_BLKS = tl.cdiv(NUM_HEADS, NUM_HEADS_PER_BLK)
+    batch_idx = head_batch_idx // NUM_HEAD_BLKS
+    head_idx = head_batch_idx % NUM_HEAD_BLKS
+
+    stride_vz = V * DIM_Z
+    stride_uv = U * V
+    stride_uvz = U * V * DIM_Z
+
+    offs_u = pid_u
+    offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+    offs_k = tl.arange(0, TILE_K)
+    offs_z = tl.arange(0, DIM_Z)
+    mask_v = offs_v < V
+    offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+    mask_head = offs_head < NUM_HEADS
+
+    z_ptrs = z_ptr + batch_idx * stride_uvz + offs_u * stride_vz
+    z_ptrs += offs_v[:, None] * DIM_Z + offs_z[None, :]
+
+    z_tile_full = tl.load(z_ptrs, mask=mask_v[:, None], other=0.0).to(tl.float32)
+
+    mean = tl.sum(z_tile_full, axis=1) / DIM_Z
+    rstd = z_tile_full - mean[:, None]
+    rstd = rstd * rstd
+    rstd = tl.sum(rstd, axis=1) / DIM_Z
+    rstd = tl.rsqrt(rstd + EPS)
+
+    if IS_TRAINING:
+        mean_ptrs = mean_ptr + batch_idx * stride_uv
+        mean_ptrs += offs_u * V + offs_v
+        tl.store(mean_ptrs, mean, mask=mask_v)
+
+        rstd_ptrs = rstd_ptr + batch_idx * stride_uv
+        rstd_ptrs += offs_u * V + offs_v
+        tl.store(rstd_ptrs, rstd, mask=mask_v)
+
+    z_ptrs = z_ptr + batch_idx * stride_uvz + offs_u * stride_vz
+    z_ptrs += offs_v[:, None] * DIM_Z + offs_k[None, :]
+    w_ln_ptrs = w_ln_ptr + offs_k
+    b_ln_ptrs = b_ln_ptr + offs_k
+    w_proj_ptrs = w_proj_z_ptr + (offs_head[None, :] * DIM_Z + offs_k[:, None])
+
+    if IS_TRAINING:
+        z_norm_ptrs = z_norm_ptr + batch_idx * stride_uvz + offs_u * stride_vz
+        z_norm_ptrs += offs_v[:, None] * DIM_Z + offs_k[None, :]
+
+    num_tiles_k = DIM_Z // TILE_K
+    acc = tl.zeros((TILE_V, NUM_HEADS_PER_BLK), dtype=tl.float32)
+
+    for _ in range(0, num_tiles_k):
+        z_tile = tl.load(z_ptrs, mask=mask_v[:, None], other=0.0).to(tl.float32)
+        z_tile = (z_tile - mean[:, None]) * rstd[:, None]
+
+        if ELEMENTWISE_AFFINE:
+            w_ln_tile = tl.load(w_ln_ptrs).to(tl.float32)
+            b_ln_tile = tl.load(b_ln_ptrs).to(tl.float32)
+            z_tile = z_tile * w_ln_tile + b_ln_tile
+
+        if IS_TRAINING:
+            tl.store(z_norm_ptrs, z_tile, mask=mask_v[:, None])
+
+        w_tile = tl.load(w_proj_ptrs, mask=mask_head[None, :], other=0.0).to(tl.float32)
+
+        acc = tl.dot(z_tile, w_tile, acc, input_precision="tf32x3")
+
+        z_ptrs += TILE_K
+        w_proj_ptrs += TILE_K
+        if ELEMENTWISE_AFFINE:
+            w_ln_ptrs += TILE_K
+            b_ln_ptrs += TILE_K
+        if IS_TRAINING:
+            z_norm_ptrs += TILE_K
+
+    if HAS_BIAS:
+        b_proj_ptrs = b_proj_z_ptr + offs_head
+        b_proj_tile = tl.load(b_proj_ptrs, mask=mask_head, other=0.0).to(tl.float32)
+        acc += b_proj_tile[None, :]
+
+    offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+    mask_v = offs_v < V
+    offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+    mask_head = offs_head < NUM_HEADS
+
+    if CACHE_Z_PROJ:
+        z_proj_ptrs = z_proj_ptr + batch_idx * NUM_HEADS * stride_uv
+        z_proj_ptrs += offs_u * V
+        z_proj_ptrs += offs_head[None, :] * stride_uv + offs_v[:, None]
+        mask_z = mask_head[None, :] & mask_v[:, None]
+        tl.store(z_proj_ptrs, acc, mask=mask_z)
+
+    out_mask_ptrs = out_mask_ptr + batch_idx * multiplicity * NUM_HEADS * stride_uv
+    out_mask_ptrs += offs_u * V
+    out_mask_ptrs += offs_head[None, :] * stride_uv + offs_v[:, None]
+    mask_o = mask_head[None, :] & mask_v[:, None]
+
+    if MASK_WITH_MULTIPLICITY:
+        mask_ptrs = mask_ptr + batch_idx * multiplicity * V + offs_v
+
+        for _ in range(multiplicity):
+            mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+            out_tile = acc + (1.0 - mask_tile[:, None]) * (-INF)
+            tl.store(out_mask_ptrs, out_tile, mask=mask_o)
+
+            out_mask_ptrs += NUM_HEADS * stride_uv
+            mask_ptrs += V
+
+    else:
+        mask_ptrs = mask_ptr + batch_idx * V + offs_v
+        mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+
+        for _ in range(multiplicity):
+            out_tile = acc + (1.0 - mask_tile[:, None]) * (-INF)
+            tl.store(out_mask_ptrs, out_tile, mask=mask_o)
+
+            out_mask_ptrs += NUM_HEADS * stride_uv
+            mask_ptrs += V
+
+
+@triton.jit
+def pair_bias_linear_mask_forward_kernel(
+    z_ptr,
+    mask_ptr,
+    w_proj_z_ptr,
+    b_proj_z_ptr,
+    U,
+    V,
+    multiplicity,
+    out_mask_ptr,
+    z_proj_ptr,
+    TILE_V: tl.constexpr,
+    TILE_K: tl.constexpr,
+    NUM_HEADS: tl.constexpr,
+    NUM_HEADS_PER_BLK: tl.constexpr,
+    DIM_Z: tl.constexpr,
+    INF: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    MASK_WITH_MULTIPLICITY: tl.constexpr,
+    CACHE_Z_PROJ: tl.constexpr = False,
+    NEEDS_INT64: tl.constexpr = True,
+):
+    # prepare single mask
+    # z: B x U x V x D -> z' B x H x U x V
+    # mask: B x V
+    # out_mask = z' + (1 - mask) * inf
+
+    pid_v = tl.program_id(0)
+    pid_u = tl.program_id(1)
+
+    if NEEDS_INT64:
+        pid_u = tl.cast(pid_u, tl.int64)
+        pid_v = tl.cast(pid_v, tl.int64)
+        U = tl.cast(U, tl.int64)
+        V = tl.cast(V, tl.int64)
+
+    head_batch_idx = tl.program_id(2)
+    NUM_HEAD_BLKS = tl.cdiv(NUM_HEADS, NUM_HEADS_PER_BLK)
+    batch_idx = head_batch_idx // NUM_HEAD_BLKS
+    head_idx = head_batch_idx % NUM_HEAD_BLKS
+
+    stride_vz = V * DIM_Z
+    stride_uv = U * V
+    stride_uvz = U * V * DIM_Z
+
+    offs_u = pid_u
+    offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+    offs_k = tl.arange(0, TILE_K)
+    mask_v = offs_v < V
+    offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+    mask_head = offs_head < NUM_HEADS
+
+    z_ptrs = z_ptr + batch_idx * stride_uvz + offs_u * stride_vz
+    z_ptrs += offs_v[:, None] * DIM_Z + offs_k[None, :]
+
+    w_ptrs = w_proj_z_ptr + (offs_head[None, :] * DIM_Z + offs_k[:, None])
+
+    acc = tl.zeros((TILE_V, NUM_HEADS_PER_BLK), dtype=tl.float32)
+
+    for _ in range(0, DIM_Z // TILE_K):
+        z_tile = tl.load(z_ptrs, mask=mask_v[:, None], other=0.0).to(tl.float32)
+        w_tile = tl.load(w_ptrs, mask=mask_head[None, :], other=0.0).to(tl.float32)
+
+        acc = tl.dot(z_tile, w_tile, acc, input_precision="tf32x3")
+
+        z_ptrs += TILE_K
+        w_ptrs += TILE_K
+
+    if HAS_BIAS:
+        b_proj_ptrs = b_proj_z_ptr + offs_head
+        b_proj_tile = tl.load(b_proj_ptrs, mask=mask_head, other=0.0)
+        acc += b_proj_tile[None, :]
+
+    offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+    mask_v = offs_v < V
+    offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+    mask_head = offs_head < NUM_HEADS
+
+    if CACHE_Z_PROJ:
+        z_proj_ptrs = z_proj_ptr + batch_idx * NUM_HEADS * stride_uv
+        z_proj_ptrs += offs_u * V
+        z_proj_ptrs += offs_head[None, :] * stride_uv + offs_v[:, None]
+        mask_z = mask_head[None, :] & mask_v[:, None]
+        tl.store(z_proj_ptrs, acc, mask=mask_z)
+
+    out_mask_ptrs = out_mask_ptr + batch_idx * multiplicity * NUM_HEADS * stride_uv
+    out_mask_ptrs += offs_u * V
+    out_mask_ptrs += offs_head[None, :] * stride_uv + offs_v[:, None]
+    mask_o = mask_head[None, :] & mask_v[:, None]
+
+    if MASK_WITH_MULTIPLICITY:
+        mask_ptrs = mask_ptr + batch_idx * multiplicity * V + offs_v
+
+        for _ in range(multiplicity):
+            mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+            out_tile = acc + (1.0 - mask_tile[:, None]) * (-INF)
+            tl.store(out_mask_ptrs, out_tile, mask=mask_o)
+
+            out_mask_ptrs += NUM_HEADS * stride_uv
+            mask_ptrs += V
+
+    else:
+        mask_ptrs = mask_ptr + batch_idx * V + offs_v
+        mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+
+        for _ in range(multiplicity):
+            out_tile = acc + (1.0 - mask_tile[:, None]) * (-INF)
+            tl.store(out_mask_ptrs, out_tile, mask=mask_o)
+
+            out_mask_ptrs += NUM_HEADS * stride_uv
+            mask_ptrs += V
+
+
+@triton.jit
+def pair_bias_mask_forward_kernel(
+    z_proj_ptr,
+    mask_ptr,
+    U,
+    V,
+    multiplicity,
+    out_mask_ptr,
+    TILE_V: tl.constexpr,
+    NUM_HEADS: tl.constexpr,
+    NUM_HEADS_PER_BLK: tl.constexpr,
+    INF: tl.constexpr,
+    MASK_WITH_MULTIPLICITY: tl.constexpr,
+    NEEDS_INT64: tl.constexpr = True,
+):
+    # prepare single mask
+    # z: z' B x H x U x V
+    # mask: B x V
+    # out_mask = z' + (1 - mask) * inf
+
+    pid_v = tl.program_id(0)
+    pid_u = tl.program_id(1)
+
+    if NEEDS_INT64:
+        pid_u = tl.cast(pid_u, tl.int64)
+        pid_v = tl.cast(pid_v, tl.int64)
+        U = tl.cast(U, tl.int64)
+        V = tl.cast(V, tl.int64)
+
+    head_batch_idx = tl.program_id(2)
+    NUM_HEAD_BLKS = tl.cdiv(NUM_HEADS, NUM_HEADS_PER_BLK)
+    batch_idx = head_batch_idx // NUM_HEAD_BLKS
+    head_idx = head_batch_idx % NUM_HEAD_BLKS
+
+    stride_uv = U * V
+    stride_uvh = U * V * NUM_HEADS
+
+    offs_u = pid_u
+    offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+    mask_v = offs_v < V
+    offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+    mask_head = offs_head < NUM_HEADS
+
+    z_proj_ptrs = z_proj_ptr + batch_idx * stride_uvh + offs_u * V
+    z_proj_ptrs += offs_head[:, None] * stride_uv + offs_v[None, :]
+
+    mask_zo = mask_v[None, :] & mask_head[:, None]
+
+    z_proj_tile = tl.load(z_proj_ptrs, mask=mask_zo, other=0.0).to(tl.float32)
+
+    out_mask_ptrs = out_mask_ptr + batch_idx * multiplicity * NUM_HEADS * stride_uv
+    out_mask_ptrs += offs_u * V
+    out_mask_ptrs += offs_head[:, None] * stride_uv + offs_v[None, :]
+
+    if MASK_WITH_MULTIPLICITY:
+        mask_ptrs = mask_ptr + batch_idx * multiplicity * V + offs_v
+
+        for _ in range(multiplicity):
+            mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+            out_tile = z_proj_tile + (1.0 - mask_tile[None, :]) * (-INF)
+            tl.store(out_mask_ptrs, out_tile, mask=mask_zo)
+
+            out_mask_ptrs += NUM_HEADS * stride_uv
+            mask_ptrs += V
+
+    else:
+        mask_ptrs = mask_ptr + batch_idx * V + offs_v
+        mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+
+        for _ in range(multiplicity):
+            out_tile = z_proj_tile + (1.0 - mask_tile[None, :]) * (-INF)
+            tl.store(out_mask_ptrs, out_tile, mask=mask_zo)
+
+            out_mask_ptrs += NUM_HEADS * stride_uv
+            mask_ptrs += V
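
The host-side wrappers that launch these kernels are not part of this hunk. As a rough orientation, the sketch below shows how the simplest of the three, pair_bias_mask_forward_kernel, could be driven from PyTorch; the grid layout, tile sizes, tensor shapes, and the pair_bias_mask_forward helper name are assumptions inferred from the kernel's pointer arithmetic, not the package's actual API.

# Hypothetical launcher, not shipped in the wheel: shapes and grid are inferred
# from the kernel's pointer arithmetic (z_proj: B x H x U x V, mask: B x V,
# out: B x multiplicity x H x U x V, all contiguous).
import torch
import triton

from cuequivariance_ops.triton.pair_bias import pair_bias_mask_forward_kernel


def pair_bias_mask_forward(z_proj, mask, multiplicity=1,
                           tile_v=128, num_heads_per_blk=8, inf=1e9):
    B, H, U, V = z_proj.shape
    assert mask.shape == (B, V)  # 1.0 = keep, 0.0 = mask out with a -INF bias
    out = torch.empty((B, multiplicity, H, U, V),
                      device=z_proj.device, dtype=torch.float32)
    # one program per (V-tile, U row, head-block x batch)
    grid = (triton.cdiv(V, tile_v), U, B * triton.cdiv(H, num_heads_per_blk))
    pair_bias_mask_forward_kernel[grid](
        z_proj, mask, U, V, multiplicity, out,
        TILE_V=tile_v,
        NUM_HEADS=H,
        NUM_HEADS_PER_BLK=num_heads_per_blk,
        INF=inf,
        MASK_WITH_MULTIPLICITY=False,  # one mask shared across the multiplicity dim
        NEEDS_INT64=out.numel() >= 2**31,  # heuristic: large tensors need 64-bit offsets
    )
    return out

The other two kernels follow the same grid scheme but additionally take the layer-norm and projection weights and reduce over DIM_Z in TILE_K steps before applying the mask.
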

cuequivariance_ops/triton/tuning_decorator.py
@@ -0,0 +1,188 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import gc
+import inspect
+import logging  # Added logging import
+from typing import Any, Callable
+
+from tqdm import tqdm
+
+from .cache_manager import get_cache_manager
+
+# import torch
+
+# Configure logging
+logger = logging.getLogger(__name__)
+
+
+def input_to_key_default(**args) -> str:
+    key_parts = []
+    for arg in args:
+        if hasattr(arg, "shape") and hasattr(arg, "dtype"):
+            key_parts.append(f"{list(arg.shape)}_{arg.dtype}")
+        elif isinstance(arg, bool):
+            key_parts.append("True" if arg else "False")
+        elif isinstance(arg, str):
+            key_parts.append(arg)
+        else:
+            key_parts.append(str(arg.__class__.__name__))
+
+    return "_".join(key_parts)
+
+
+def combine_all_kwargs(
+    fn: Callable,
+    args: tuple,
+    kwargs: dict[str, Any],
+) -> dict[str, Any]:
+    # Get the function signature
+    sig = inspect.signature(fn)
+    params = sig.parameters
+    param_names = list(params.keys())
+
+    # Create dictionary of default values
+    defaults = {
+        name: param.default
+        for name, param in params.items()
+        if param.default is not inspect.Parameter.empty
+    }
+    # Create dictionary mapping positional args to parameter names
+    args_as_kwargs = {
+        param_names[i]: args[i] for i in range(min(len(args), len(param_names)))
+    }
+    # Create combined dictionary of all parameters
+    all_kwargs = defaults.copy()  # Start with defaults
+    all_kwargs.update(args_as_kwargs)  # Override with positional args
+    all_kwargs.update(kwargs)  # Override with explicit kwargs
+
+    return all_kwargs
+
+
+def autotune_aot(
+    input_generator: Callable,
+    input_to_key: Callable | None,
+    input_configs: list[dict[str, Any]],
+    tunable_configs: list[dict[str, Any]],
+    prune_configs_fn: Callable[
+        [list[dict[str, Any]], dict[str, Any]], list[dict[str, Any]]
+    ]
+    | None,
+    run_decoy: Callable[[Callable, dict[str, Any]], None],
+    run_bench: Callable[[Callable, dict[str, Any]], float],
+) -> None:
+    def decorator(fn: Callable) -> Callable:
+        def wrapper(*args, **kwargs):
+            all_kwargs = combine_all_kwargs(fn, args, kwargs)
+            nonlocal input_to_key
+            nonlocal input_configs
+
+            if input_to_key is None:
+                input_to_key = input_to_key_default
+
+            # Check if the function is already cached
+            function_key = fn.__name__
+            input_key = input_to_key(**all_kwargs)
+            cache_manager = get_cache_manager()
+            best_cached_config = cache_manager.get(function_key, input_key)
+
+            aot_mode = cache_manager.aot_mode
+
+            if best_cached_config is None and aot_mode is not None:
+                # start autotuning process
+                # input_configs = input_configs + [None]
+                if aot_mode == "ONDEMAND":
+                    input_configs = [None]
+
+                try:
+                    # Initialize the progress bar
+                    progress_bar = tqdm(
+                        input_configs, desc="Autotuning Progress", unit="config"
+                    )
+
+                    for input_config in progress_bar:
+                        # generate input based on the config
+                        input_data = (
+                            input_generator(**input_config)
+                            if input_config is not None
+                            else all_kwargs
+                        )
+
+                        # Make a copy of all_kwargs to avoid modifying the original
+                        current_kwargs = all_kwargs.copy()
+                        current_kwargs.update(input_data)
+                        current_input_key = input_to_key(**current_kwargs)
+
+                        best_cached_config = cache_manager.get(
+                            function_key, current_input_key
+                        )
+
+                        if best_cached_config is not None:
+                            continue
+
+                        # print(f"Running for key: {current_input_key}")
+
+                        # prune the tunable configs based on the all_kwargs
+                        pruned_tunable_configs = (
+                            prune_configs_fn(tunable_configs, **all_kwargs)
+                            if prune_configs_fn is not None
+                            else tunable_configs
+                        )
+
+                        best_config = None
+                        best_time = float("inf")
+                        working_config = []
+                        for tunable in pruned_tunable_configs:
+                            try:
+                                current_kwargs.update(tunable)
+                                run_decoy(fn, current_kwargs)
+                                working_config.append(tunable)
+                            except Exception:
+                                pass
+
+                        if not working_config:
+                            logger.warning(
+                                f"No valid configurations found for input: {current_input_key}"
+                            )
+                            continue
+
+                        for tunable in working_config:
+                            current_kwargs.update(tunable)
+                            elapse = run_bench(fn, current_kwargs)
+                            if elapse < best_time:
+                                best_time = elapse
+                                best_config = tunable
+
+                        cache_manager.set(
+                            function_key,
+                            current_input_key,
+                            {"config": best_config, "time": best_time},
+                        )
+                        current_kwargs = None
+                        input_data = None
+                        gc.collect()
+                        # torch.cuda.empty_cache()
+                        if (progress_bar.n % 1000) == 1:
+                            cache_manager.save_cache(function_key)
+                    cache_manager.save_cache(function_key)
+                except Exception as e:
+                    print(f"Stopping autotuning due to error: {e}")
+
+            # After tuning, try to get the best config
+            best_cached_config = cache_manager.get(function_key, input_key)
+
+            if best_cached_config is not None:
+                all_kwargs.update(best_cached_config["config"])
+
+            return fn(**all_kwargs)
+
+        return wrapper
+
+    return decorator
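
autotune_aot is a decorator factory: it wraps a kernel entry point, builds a cache key from the call's arguments, rejects tunable configs that fail run_decoy, times the survivors with run_bench, and stores the winner through the cache manager. The sketch below is an illustrative application only; every helper and config value in it is invented, since the real call sites live in other modules of the package.

# Illustrative use of autotune_aot; all helpers and configs here are made up.
import time

from cuequivariance_ops.triton.tuning_decorator import autotune_aot


def _gen_inputs(n, dtype):
    # Would normally allocate device tensors matching this shape/dtype config.
    return {"n": n, "dtype": dtype}


def _run_decoy(fn, kwargs):
    fn(**kwargs)  # compile/run once so failing configs raise and get dropped


def _run_bench(fn, kwargs):
    t0 = time.perf_counter()
    fn(**kwargs)
    return time.perf_counter() - t0


@autotune_aot(
    input_generator=_gen_inputs,
    input_to_key=None,  # fall back to input_to_key_default
    input_configs=[{"n": 1024, "dtype": "float32"}, {"n": 4096, "dtype": "float32"}],
    tunable_configs=[{"TILE": 64}, {"TILE": 128}],
    prune_configs_fn=None,
    run_decoy=_run_decoy,
    run_bench=_run_bench,
)
def my_kernel_wrapper(n, dtype, TILE=64):
    ...  # launch the underlying Triton kernel with the chosen TILE

Whether tuning runs at all, and whether it sweeps the predefined input_configs or only the live call's inputs (the "ONDEMAND" mode), is controlled by the aot_mode of the object returned from get_cache_manager().
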

cuequivariance_ops/triton/utils.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import enum
+
+import triton
+import triton.language as tl
+
+
+class Precision(enum.Enum):
+    DEFAULT = 0
+    TF32 = 1
+    TF32x3 = 2
+    IEEE = 3
+    NONE = -1
+
+
+@triton.jit
+def cvt_tf32_rn(x: tl.tensor) -> tl.tensor:
+    return tl.inline_asm_elementwise(
+        "cvt.rna.tf32.f32 $0, $1;", "=r, r", [x], dtype=tl.float32, is_pure=True, pack=1
+    )
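
cvt_tf32_rn wraps a single PTX cvt.rna.tf32.f32 instruction, rounding an fp32 value to the nearest TF32-representable value while leaving it in an fp32 register; the Precision enum apparently labels the matmul precision modes used elsewhere in the package (compare the input_precision="tf32x3" dots in pair_bias.py). A toy kernel using the helper might look like the following; the kernel itself is illustrative, not part of the wheel.

# Illustrative only: round an fp32 buffer to TF32 precision with cvt_tf32_rn.
import triton
import triton.language as tl

from cuequivariance_ops.triton.utils import cvt_tf32_rn


@triton.jit
def round_to_tf32_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    m = offs < n
    x = tl.load(x_ptr + offs, mask=m, other=0.0)
    # Round to the nearest TF32 value (10 mantissa bits); the result stays fp32.
    y = cvt_tf32_rn(x)
    tl.store(y_ptr + offs, y, mask=m)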