cuequivariance-ops-cu12 0.8.1__py3-none-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl

This diff shows the content of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions exactly as they appear in the public registry.
Files changed (46)
  1. cuequivariance_ops/VERSION +1 -0
  2. cuequivariance_ops/__init__.py +42 -0
  3. cuequivariance_ops/_version.py +20 -0
  4. cuequivariance_ops/common/common.hpp +98 -0
  5. cuequivariance_ops/common/cudart.hpp +286 -0
  6. cuequivariance_ops/common/error.hpp +66 -0
  7. cuequivariance_ops/common/error_raft.hpp +323 -0
  8. cuequivariance_ops/common/nvtx.hpp +29 -0
  9. cuequivariance_ops/equivariance/batch_dimension.hh +15 -0
  10. cuequivariance_ops/equivariance/dtypes.hh +65 -0
  11. cuequivariance_ops/equivariance/fused_tensor_product.cuh +297 -0
  12. cuequivariance_ops/equivariance/indexed_linear.hh +41 -0
  13. cuequivariance_ops/equivariance/run_fmha.h +192 -0
  14. cuequivariance_ops/equivariance/run_fmha_cudafree.h +176 -0
  15. cuequivariance_ops/equivariance/run_fmha_sm100.h +135 -0
  16. cuequivariance_ops/equivariance/segmented_transpose.cuh +40 -0
  17. cuequivariance_ops/equivariance/tensor_product_uniform_1d_jit.hh +38 -0
  18. cuequivariance_ops/gpu_timing_kernels.hh +42 -0
  19. cuequivariance_ops/lib/libcue_ops.so +0 -0
  20. cuequivariance_ops/sleep.hh +40 -0
  21. cuequivariance_ops/triton/__init__.py +66 -0
  22. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.10.0.json +37142 -0
  23. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.12.0.json +37132 -0
  24. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.0.json +37133 -0
  25. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.6.json +37133 -0
  26. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.8.9.json +37132 -0
  27. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper.9.0.json +74262 -0
  28. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.10.0.json +48482 -0
  29. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.12.0.json +55692 -0
  30. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.0.json +55693 -0
  31. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.6.json +55692 -0
  32. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.8.9.json +55693 -0
  33. cuequivariance_ops/triton/cache/fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper.9.0.json +111382 -0
  34. cuequivariance_ops/triton/cache_manager.py +336 -0
  35. cuequivariance_ops/triton/fused_layer_norm_triton.py +546 -0
  36. cuequivariance_ops/triton/gated_gemm_triton.py +394 -0
  37. cuequivariance_ops/triton/pair_bias.py +365 -0
  38. cuequivariance_ops/triton/tuning_decorator.py +188 -0
  39. cuequivariance_ops/triton/utils.py +29 -0
  40. cuequivariance_ops_cu12-0.8.1.dist-info/METADATA +182 -0
  41. cuequivariance_ops_cu12-0.8.1.dist-info/RECORD +46 -0
  42. cuequivariance_ops_cu12-0.8.1.dist-info/WHEEL +6 -0
  43. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/LICENSE +142 -0
  44. cuequivariance_ops_cu12-0.8.1.dist-info/licenses/Third_party_attr.txt +24 -0
  45. cuequivariance_ops_cu12-0.8.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  46. cuequivariance_ops_cu12.libs/libnvfatbin-b51d3b3f.so.12.8.90 +0 -0
cuequivariance_ops/triton/pair_bias.py
@@ -0,0 +1,365 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def pair_bias_norm_linear_mask_forward_kernel(
+    z_ptr,
+    mask_ptr,
+    w_proj_z_ptr,
+    b_proj_z_ptr,
+    w_ln_ptr,
+    b_ln_ptr,
+    U,
+    V,
+    multiplicity,
+    out_mask_ptr,
+    z_norm_ptr,
+    z_proj_ptr,
+    mean_ptr,
+    rstd_ptr,
+    TILE_V: tl.constexpr,
+    TILE_K: tl.constexpr,
+    NUM_HEADS: tl.constexpr,
+    NUM_HEADS_PER_BLK: tl.constexpr,
+    DIM_Z: tl.constexpr,
+    INF: tl.constexpr,
+    EPS: tl.constexpr,
+    ELEMENTWISE_AFFINE: tl.constexpr,
+    IS_TRAINING: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    MASK_WITH_MULTIPLICITY: tl.constexpr,
+    CACHE_Z_PROJ: tl.constexpr = False,
+    NEEDS_INT64: tl.constexpr = True,
+):
+    # prepare single mask
+    # z: B x U x V x D -> z' B x H x U x V
+    # mask: B x V
+    # out_mask = z' + (1 - mask) * inf
+
+    pid_v = tl.program_id(0)
+    pid_u = tl.program_id(1)
+
+    if NEEDS_INT64:
+        pid_u = tl.cast(pid_u, tl.int64)
+        pid_v = tl.cast(pid_v, tl.int64)
+        U = tl.cast(U, tl.int64)
+        V = tl.cast(V, tl.int64)
+
+    head_batch_idx = tl.program_id(2)
+    NUM_HEAD_BLKS = tl.cdiv(NUM_HEADS, NUM_HEADS_PER_BLK)
+    batch_idx = head_batch_idx // NUM_HEAD_BLKS
+    head_idx = head_batch_idx % NUM_HEAD_BLKS
+
+    stride_vz = V * DIM_Z
+    stride_uv = U * V
+    stride_uvz = U * V * DIM_Z
+
+    offs_u = pid_u
+    offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+    offs_k = tl.arange(0, TILE_K)
+    offs_z = tl.arange(0, DIM_Z)
+    mask_v = offs_v < V
+    offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+    mask_head = offs_head < NUM_HEADS
+
+    z_ptrs = z_ptr + batch_idx * stride_uvz + offs_u * stride_vz
+    z_ptrs += offs_v[:, None] * DIM_Z + offs_z[None, :]
+
+    z_tile_full = tl.load(z_ptrs, mask=mask_v[:, None], other=0.0).to(tl.float32)
+
+    mean = tl.sum(z_tile_full, axis=1) / DIM_Z
+    rstd = z_tile_full - mean[:, None]
+    rstd = rstd * rstd
+    rstd = tl.sum(rstd, axis=1) / DIM_Z
+    rstd = tl.rsqrt(rstd + EPS)
+
+    if IS_TRAINING:
+        mean_ptrs = mean_ptr + batch_idx * stride_uv
+        mean_ptrs += offs_u * V + offs_v
+        tl.store(mean_ptrs, mean, mask=mask_v)
+
+        rstd_ptrs = rstd_ptr + batch_idx * stride_uv
+        rstd_ptrs += offs_u * V + offs_v
+        tl.store(rstd_ptrs, rstd, mask=mask_v)
+
+    z_ptrs = z_ptr + batch_idx * stride_uvz + offs_u * stride_vz
+    z_ptrs += offs_v[:, None] * DIM_Z + offs_k[None, :]
+    w_ln_ptrs = w_ln_ptr + offs_k
+    b_ln_ptrs = b_ln_ptr + offs_k
+    w_proj_ptrs = w_proj_z_ptr + (offs_head[None, :] * DIM_Z + offs_k[:, None])
+
+    if IS_TRAINING:
+        z_norm_ptrs = z_norm_ptr + batch_idx * stride_uvz + offs_u * stride_vz
+        z_norm_ptrs += offs_v[:, None] * DIM_Z + offs_k[None, :]
+
+    num_tiles_k = DIM_Z // TILE_K
+    acc = tl.zeros((TILE_V, NUM_HEADS_PER_BLK), dtype=tl.float32)
+
+    for _ in range(0, num_tiles_k):
+        z_tile = tl.load(z_ptrs, mask=mask_v[:, None], other=0.0).to(tl.float32)
+        z_tile = (z_tile - mean[:, None]) * rstd[:, None]
+
+        if ELEMENTWISE_AFFINE:
+            w_ln_tile = tl.load(w_ln_ptrs).to(tl.float32)
+            b_ln_tile = tl.load(b_ln_ptrs).to(tl.float32)
+            z_tile = z_tile * w_ln_tile + b_ln_tile
+
+        if IS_TRAINING:
+            tl.store(z_norm_ptrs, z_tile, mask=mask_v[:, None])
+
+        w_tile = tl.load(w_proj_ptrs, mask=mask_head[None, :], other=0.0).to(tl.float32)
+
+        acc = tl.dot(z_tile, w_tile, acc, input_precision="tf32x3")
+
+        z_ptrs += TILE_K
+        w_proj_ptrs += TILE_K
+        if ELEMENTWISE_AFFINE:
+            w_ln_ptrs += TILE_K
+            b_ln_ptrs += TILE_K
+        if IS_TRAINING:
+            z_norm_ptrs += TILE_K
+
+    if HAS_BIAS:
+        b_proj_ptrs = b_proj_z_ptr + offs_head
+        b_proj_tile = tl.load(b_proj_ptrs, mask=mask_head, other=0.0).to(tl.float32)
+        acc += b_proj_tile[None, :]
+
+    offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+    mask_v = offs_v < V
+    offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+    mask_head = offs_head < NUM_HEADS
+
+    if CACHE_Z_PROJ:
+        z_proj_ptrs = z_proj_ptr + batch_idx * NUM_HEADS * stride_uv
+        z_proj_ptrs += offs_u * V
+        z_proj_ptrs += offs_head[None, :] * stride_uv + offs_v[:, None]
+        mask_z = mask_head[None, :] & mask_v[:, None]
+        tl.store(z_proj_ptrs, acc, mask=mask_z)
+
+    out_mask_ptrs = out_mask_ptr + batch_idx * multiplicity * NUM_HEADS * stride_uv
+    out_mask_ptrs += offs_u * V
+    out_mask_ptrs += offs_head[None, :] * stride_uv + offs_v[:, None]
+    mask_o = mask_head[None, :] & mask_v[:, None]
+
+    if MASK_WITH_MULTIPLICITY:
+        mask_ptrs = mask_ptr + batch_idx * multiplicity * V + offs_v
+
+        for _ in range(multiplicity):
+            mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+            out_tile = acc + (1.0 - mask_tile[:, None]) * (-INF)
+            tl.store(out_mask_ptrs, out_tile, mask=mask_o)

+            out_mask_ptrs += NUM_HEADS * stride_uv
+            mask_ptrs += V
+
+    else:
+        mask_ptrs = mask_ptr + batch_idx * V + offs_v
+        mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+
+        for _ in range(multiplicity):
+            out_tile = acc + (1.0 - mask_tile[:, None]) * (-INF)
+            tl.store(out_mask_ptrs, out_tile, mask=mask_o)
+
+            out_mask_ptrs += NUM_HEADS * stride_uv
+            mask_ptrs += V
+
+
+@triton.jit
+def pair_bias_linear_mask_forward_kernel(
+    z_ptr,
+    mask_ptr,
+    w_proj_z_ptr,
+    b_proj_z_ptr,
+    U,
+    V,
+    multiplicity,
+    out_mask_ptr,
+    z_proj_ptr,
+    TILE_V: tl.constexpr,
+    TILE_K: tl.constexpr,
+    NUM_HEADS: tl.constexpr,
+    NUM_HEADS_PER_BLK: tl.constexpr,
+    DIM_Z: tl.constexpr,
+    INF: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    MASK_WITH_MULTIPLICITY: tl.constexpr,
+    CACHE_Z_PROJ: tl.constexpr = False,
+    NEEDS_INT64: tl.constexpr = True,
+):
+    # prepare single mask
+    # z: B x U x V x D -> z' B x H x U x V
+    # mask: B x V
+    # out_mask = z' + (1 - mask) * inf
+
+    pid_v = tl.program_id(0)
+    pid_u = tl.program_id(1)
+
+    if NEEDS_INT64:
+        pid_u = tl.cast(pid_u, tl.int64)
+        pid_v = tl.cast(pid_v, tl.int64)
+        U = tl.cast(U, tl.int64)
+        V = tl.cast(V, tl.int64)
+
+    head_batch_idx = tl.program_id(2)
+    NUM_HEAD_BLKS = tl.cdiv(NUM_HEADS, NUM_HEADS_PER_BLK)
+    batch_idx = head_batch_idx // NUM_HEAD_BLKS
+    head_idx = head_batch_idx % NUM_HEAD_BLKS
+
+    stride_vz = V * DIM_Z
+    stride_uv = U * V
+    stride_uvz = U * V * DIM_Z
+
+    offs_u = pid_u
+    offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+    offs_k = tl.arange(0, TILE_K)
+    mask_v = offs_v < V
+    offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+    mask_head = offs_head < NUM_HEADS
+
+    z_ptrs = z_ptr + batch_idx * stride_uvz + offs_u * stride_vz
+    z_ptrs += offs_v[:, None] * DIM_Z + offs_k[None, :]
+
+    w_ptrs = w_proj_z_ptr + (offs_head[None, :] * DIM_Z + offs_k[:, None])
+
+    acc = tl.zeros((TILE_V, NUM_HEADS_PER_BLK), dtype=tl.float32)
+
+    for _ in range(0, DIM_Z // TILE_K):
+        z_tile = tl.load(z_ptrs, mask=mask_v[:, None], other=0.0).to(tl.float32)
+        w_tile = tl.load(w_ptrs, mask=mask_head[None, :], other=0.0).to(tl.float32)
+
+        acc = tl.dot(z_tile, w_tile, acc, input_precision="tf32x3")
+
+        z_ptrs += TILE_K
+        w_ptrs += TILE_K
+
+    if HAS_BIAS:
+        b_proj_ptrs = b_proj_z_ptr + offs_head
+        b_proj_tile = tl.load(b_proj_ptrs, mask=mask_head, other=0.0)
+        acc += b_proj_tile[None, :]
+
+    offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+    mask_v = offs_v < V
+    offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+    mask_head = offs_head < NUM_HEADS
+
+    if CACHE_Z_PROJ:
+        z_proj_ptrs = z_proj_ptr + batch_idx * NUM_HEADS * stride_uv
+        z_proj_ptrs += offs_u * V
+        z_proj_ptrs += offs_head[None, :] * stride_uv + offs_v[:, None]
+        mask_z = mask_head[None, :] & mask_v[:, None]
+        tl.store(z_proj_ptrs, acc, mask=mask_z)
+
+    out_mask_ptrs = out_mask_ptr + batch_idx * multiplicity * NUM_HEADS * stride_uv
+    out_mask_ptrs += offs_u * V
+    out_mask_ptrs += offs_head[None, :] * stride_uv + offs_v[:, None]
+    mask_o = mask_head[None, :] & mask_v[:, None]
+
+    if MASK_WITH_MULTIPLICITY:
+        mask_ptrs = mask_ptr + batch_idx * multiplicity * V + offs_v
+
+        for _ in range(multiplicity):
+            mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+            out_tile = acc + (1.0 - mask_tile[:, None]) * (-INF)
+            tl.store(out_mask_ptrs, out_tile, mask=mask_o)
+
+            out_mask_ptrs += NUM_HEADS * stride_uv
+            mask_ptrs += V
+
+    else:
+        mask_ptrs = mask_ptr + batch_idx * V + offs_v
+        mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+
+        for _ in range(multiplicity):
+            out_tile = acc + (1.0 - mask_tile[:, None]) * (-INF)
+            tl.store(out_mask_ptrs, out_tile, mask=mask_o)
+
+            out_mask_ptrs += NUM_HEADS * stride_uv
+            mask_ptrs += V
+
+
+@triton.jit
+def pair_bias_mask_forward_kernel(
+    z_proj_ptr,
+    mask_ptr,
+    U,
+    V,
+    multiplicity,
+    out_mask_ptr,
+    TILE_V: tl.constexpr,
+    NUM_HEADS: tl.constexpr,
+    NUM_HEADS_PER_BLK: tl.constexpr,
+    INF: tl.constexpr,
+    MASK_WITH_MULTIPLICITY: tl.constexpr,
+    NEEDS_INT64: tl.constexpr = True,
+):
+    # prepare single mask
+    # z: z' B x H x U x V
+    # mask: B x V
+    # out_mask = z' + (1 - mask) * inf
+
+    pid_v = tl.program_id(0)
+    pid_u = tl.program_id(1)
+
+    if NEEDS_INT64:
+        pid_u = tl.cast(pid_u, tl.int64)
+        pid_v = tl.cast(pid_v, tl.int64)
+        U = tl.cast(U, tl.int64)
+        V = tl.cast(V, tl.int64)
+
+    head_batch_idx = tl.program_id(2)
+    NUM_HEAD_BLKS = tl.cdiv(NUM_HEADS, NUM_HEADS_PER_BLK)
+    batch_idx = head_batch_idx // NUM_HEAD_BLKS
+    head_idx = head_batch_idx % NUM_HEAD_BLKS
+
+    stride_uv = U * V
+    stride_uvh = U * V * NUM_HEADS
+
+    offs_u = pid_u
+    offs_v = pid_v * TILE_V + tl.arange(0, TILE_V)
+    mask_v = offs_v < V
+    offs_head = head_idx * NUM_HEADS_PER_BLK + tl.arange(0, NUM_HEADS_PER_BLK)
+    mask_head = offs_head < NUM_HEADS
+
+    z_proj_ptrs = z_proj_ptr + batch_idx * stride_uvh + offs_u * V
+    z_proj_ptrs += offs_head[:, None] * stride_uv + offs_v[None, :]
+
+    mask_zo = mask_v[None, :] & mask_head[:, None]
+
+    z_proj_tile = tl.load(z_proj_ptrs, mask=mask_zo, other=0.0).to(tl.float32)
+
+    out_mask_ptrs = out_mask_ptr + batch_idx * multiplicity * NUM_HEADS * stride_uv
+    out_mask_ptrs += offs_u * V
+    out_mask_ptrs += offs_head[:, None] * stride_uv + offs_v[None, :]
+
+    if MASK_WITH_MULTIPLICITY:
+        mask_ptrs = mask_ptr + batch_idx * multiplicity * V + offs_v
+
+        for _ in range(multiplicity):
+            mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+            out_tile = z_proj_tile + (1.0 - mask_tile[None, :]) * (-INF)
+            tl.store(out_mask_ptrs, out_tile, mask=mask_zo)
+
+            out_mask_ptrs += NUM_HEADS * stride_uv
+            mask_ptrs += V
+
+    else:
+        mask_ptrs = mask_ptr + batch_idx * V + offs_v
+        mask_tile = tl.load(mask_ptrs, mask=mask_v, other=0.0).to(tl.float32)
+
+        for _ in range(multiplicity):
+            out_tile = z_proj_tile + (1.0 - mask_tile[None, :]) * (-INF)
+            tl.store(out_mask_ptrs, out_tile, mask=mask_zo)
+
+            out_mask_ptrs += NUM_HEADS * stride_uv
+            mask_ptrs += V
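
All three kernels compute variants of the same pair-bias preparation: optionally LayerNorm `z` over its channel dimension, project it to one scalar per attention head, then add `(1 - mask) * (-INF)` so that padded positions drop out of a subsequent softmax, repeating the result `multiplicity` times. A minimal eager-mode sketch of what `pair_bias_norm_linear_mask_forward_kernel` produces is shown below; it is illustrative only (the function name, tensor names, and the PyTorch calls are assumptions, and the kernel's training-time side outputs `mean`, `rstd`, `z_norm`, `z_proj` are omitted):

```python
import torch
import torch.nn.functional as F


def pair_bias_norm_linear_mask_reference(z, mask, w_proj, b_proj, w_ln, b_ln,
                                          multiplicity, eps=1e-5, inf=1e9):
    """Eager reference, assuming ELEMENTWISE_AFFINE and HAS_BIAS are set.

    z:      (B, U, V, D)  pair representation
    mask:   (B, V) or (B, multiplicity, V), 1 = keep, 0 = padded
    w_proj: (H, D), b_proj: (H,)  per-head projection of the pair bias
    w_ln, b_ln: (D,)  LayerNorm affine parameters
    returns out_mask of shape (B, multiplicity, H, U, V)
    """
    B, U, V, D = z.shape
    z_norm = F.layer_norm(z.float(), (D,), w_ln.float(), b_ln.float(), eps)
    # (B, U, V, D) x (H, D) -> (B, H, U, V), accumulated in fp32 like the tf32x3 dot
    z_proj = torch.einsum("buvd,hd->bhuv", z_norm, w_proj.float())
    z_proj = z_proj + b_proj.float()[None, :, None, None]
    if mask.dim() == 2:  # a single mask shared across the multiplicity dimension
        mask = mask[:, None, :].expand(B, multiplicity, V)
    bias = (1.0 - mask.float()) * (-inf)  # 0 where kept, -inf where padded
    return z_proj[:, None] + bias[:, :, None, None, :]
```

`pair_bias_linear_mask_forward_kernel` is the same computation without the LayerNorm, and `pair_bias_mask_forward_kernel` starts from an already projected `z_proj` of shape (B, H, U, V) and only applies the mask.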
cuequivariance_ops/triton/tuning_decorator.py
@@ -0,0 +1,188 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import gc
+import inspect
+import logging  # Added logging import
+from typing import Any, Callable
+
+from tqdm import tqdm
+
+from .cache_manager import get_cache_manager
+
+# import torch
+
+# Configure logging
+logger = logging.getLogger(__name__)
+
+
+def input_to_key_default(**args) -> str:
+    key_parts = []
+    for arg in args.values():  # inspect argument values so tensor shapes/dtypes reach the checks below
+        if hasattr(arg, "shape") and hasattr(arg, "dtype"):
+            key_parts.append(f"{list(arg.shape)}_{arg.dtype}")
+        elif isinstance(arg, bool):
+            key_parts.append("True" if arg else "False")
+        elif isinstance(arg, str):
+            key_parts.append(arg)
+        else:
+            key_parts.append(str(arg.__class__.__name__))
+
+    return "_".join(key_parts)
+
+
+def combine_all_kwargs(
+    fn: Callable,
+    args: tuple,
+    kwargs: dict[str, Any],
+) -> dict[str, Any]:
+    # Get the function signature
+    sig = inspect.signature(fn)
+    params = sig.parameters
+    param_names = list(params.keys())
+
+    # Create dictionary of default values
+    defaults = {
+        name: param.default
+        for name, param in params.items()
+        if param.default is not inspect.Parameter.empty
+    }
+    # Create dictionary mapping positional args to parameter names
+    args_as_kwargs = {
+        param_names[i]: args[i] for i in range(min(len(args), len(param_names)))
+    }
+    # Create combined dictionary of all parameters
+    all_kwargs = defaults.copy()  # Start with defaults
+    all_kwargs.update(args_as_kwargs)  # Override with positional args
+    all_kwargs.update(kwargs)  # Override with explicit kwargs
+
+    return all_kwargs
+
+
+def autotune_aot(
+    input_generator: Callable,
+    input_to_key: Callable | None,
+    input_configs: list[dict[str, Any]],
+    tunable_configs: list[dict[str, Any]],
+    prune_configs_fn: Callable[
+        [list[dict[str, Any]], dict[str, Any]], list[dict[str, Any]]
+    ]
+    | None,
+    run_decoy: Callable[[Callable, dict[str, Any]], None],
+    run_bench: Callable[[Callable, dict[str, Any]], float],
+) -> Callable:
+    def decorator(fn: Callable) -> Callable:
+        def wrapper(*args, **kwargs):
+            all_kwargs = combine_all_kwargs(fn, args, kwargs)
+            nonlocal input_to_key
+            nonlocal input_configs
+
+            if input_to_key is None:
+                input_to_key = input_to_key_default
+
+            # Check if the function is already cached
+            function_key = fn.__name__
+            input_key = input_to_key(**all_kwargs)
+            cache_manager = get_cache_manager()
+            best_cached_config = cache_manager.get(function_key, input_key)
+
+            aot_mode = cache_manager.aot_mode
+
+            if best_cached_config is None and aot_mode is not None:
+                # start autotuning process
+                # input_configs = input_configs + [None]
+                if aot_mode == "ONDEMAND":
+                    input_configs = [None]
+
+                try:
+                    # Initialize the progress bar
+                    progress_bar = tqdm(
+                        input_configs, desc="Autotuning Progress", unit="config"
+                    )
+
+                    for input_config in progress_bar:
+                        # generate input based on the config
+                        input_data = (
+                            input_generator(**input_config)
+                            if input_config is not None
+                            else all_kwargs
+                        )
+
+                        # Make a copy of all_kwargs to avoid modifying the original
+                        current_kwargs = all_kwargs.copy()
+                        current_kwargs.update(input_data)
+                        current_input_key = input_to_key(**current_kwargs)
+
+                        best_cached_config = cache_manager.get(
+                            function_key, current_input_key
+                        )
+
+                        if best_cached_config is not None:
+                            continue
+
+                        # print(f"Running for key: {current_input_key}")
+
+                        # prune the tunable configs based on the all_kwargs
+                        pruned_tunable_configs = (
+                            prune_configs_fn(tunable_configs, **all_kwargs)
+                            if prune_configs_fn is not None
+                            else tunable_configs
+                        )
+
+                        best_config = None
+                        best_time = float("inf")
+                        working_config = []
+                        for tunable in pruned_tunable_configs:
+                            try:
+                                current_kwargs.update(tunable)
+                                run_decoy(fn, current_kwargs)
+                                working_config.append(tunable)
+                            except Exception:
+                                pass
+
+                        if not working_config:
+                            logger.warning(
+                                f"No valid configurations found for input: {current_input_key}"
+                            )
+                            continue
+
+                        for tunable in working_config:
+                            current_kwargs.update(tunable)
+                            elapse = run_bench(fn, current_kwargs)
+                            if elapse < best_time:
+                                best_time = elapse
+                                best_config = tunable
+
+                        cache_manager.set(
+                            function_key,
+                            current_input_key,
+                            {"config": best_config, "time": best_time},
+                        )
+                        current_kwargs = None
+                        input_data = None
+                        gc.collect()
+                        # torch.cuda.empty_cache()
+                        if (progress_bar.n % 1000) == 1:
+                            cache_manager.save_cache(function_key)
+                    cache_manager.save_cache(function_key)
+                except Exception as e:
+                    print(f"Stopping autotuning due to error: {e}")
+
+            # After tuning, try to get the best config
+            best_cached_config = cache_manager.get(function_key, input_key)
+
+            if best_cached_config is not None:
+                all_kwargs.update(best_cached_config["config"])
+
+            return fn(**all_kwargs)
+
+        return wrapper
+
+    return decorator
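
`autotune_aot` is a decorator factory: `combine_all_kwargs` merges defaults, positional arguments, and keyword arguments into a single dict, `input_to_key` turns that dict into a cache key, and the benchmark sweep over `tunable_configs` runs only when the cache manager has no entry for that key and its AOT mode is enabled; otherwise the wrapped function is called directly with the best cached configuration merged in. Below is a hedged usage sketch, assuming the module is importable as `cuequivariance_ops.triton.tuning_decorator`; `make_inputs`, the timing helpers, and the `BLOCK_*` parameters are hypothetical:

```python
import time

import torch

from cuequivariance_ops.triton.tuning_decorator import autotune_aot  # assumed import path


def make_inputs(m: int, n: int) -> dict:
    # hypothetical input generator: builds the kwargs for one problem size
    return {"a": torch.randn(m, n, device="cuda"), "b": torch.randn(n, m, device="cuda")}


def run_once(fn, kwargs) -> None:
    fn(**kwargs)  # "decoy" run: checks that the config compiles and launches at all


def bench(fn, kwargs) -> float:
    torch.cuda.synchronize()
    start = time.perf_counter()
    fn(**kwargs)
    torch.cuda.synchronize()
    return time.perf_counter() - start


@autotune_aot(
    input_generator=make_inputs,
    input_to_key=None,  # fall back to input_to_key_default
    input_configs=[{"m": 128, "n": 128}, {"m": 1024, "n": 1024}],
    tunable_configs=[{"BLOCK_M": 64, "BLOCK_N": 64}, {"BLOCK_M": 128, "BLOCK_N": 64}],
    prune_configs_fn=None,
    run_decoy=run_once,
    run_bench=bench,
)
def matmul_wrapper(a, b, BLOCK_M=64, BLOCK_N=64):
    # stand-in for a real Triton kernel launch; the tuned BLOCK_* values are
    # injected from the cache before this body runs
    return a @ b


out = matmul_wrapper(torch.randn(256, 256, device="cuda"),
                     torch.randn(256, 256, device="cuda"))
```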
cuequivariance_ops/triton/utils.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import enum
+
+import triton
+import triton.language as tl
+
+
+class Precision(enum.Enum):
+    DEFAULT = 0
+    TF32 = 1
+    TF32x3 = 2
+    IEEE = 3
+    NONE = -1
+
+
+@triton.jit
+def cvt_tf32_rn(x: tl.tensor) -> tl.tensor:
+    return tl.inline_asm_elementwise(
+        "cvt.rna.tf32.f32 $0, $1;", "=r, r", [x], dtype=tl.float32, is_pure=True, pack=1
+    )
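
`cvt_tf32_rn` rounds an fp32 value to the nearest TF32-representable value (the PTX `cvt.rna.tf32.f32` instruction keeps the 8-bit exponent and rounds the mantissa to TF32's 10 explicit bits), which is useful when a kernel wants explicit control over where TF32 rounding happens. The `Precision` enum mirrors the `input_precision` choices accepted by `tl.dot` (the pair-bias kernels above use `"tf32x3"`). A minimal sketch of calling it from another Triton kernel follows; the kernel, the host-side launch, and the import path are illustrative assumptions:

```python
import torch
import triton
import triton.language as tl

from cuequivariance_ops.triton.utils import cvt_tf32_rn  # assumed import path


@triton.jit
def round_to_tf32_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    m = offs < n
    x = tl.load(x_ptr + offs, mask=m, other=0.0)
    y = cvt_tf32_rn(x)  # still stored as fp32, but every value now lies on the TF32 grid
    tl.store(y_ptr + offs, y, mask=m)


x = torch.randn(1024, device="cuda", dtype=torch.float32)
y = torch.empty_like(x)
round_to_tf32_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), BLOCK=256)
# The low 13 mantissa bits of each element of y are now zero.
```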