fa4 4.0.0b3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. fa4-4.0.0b3/.flake8 +4 -0
  2. fa4-4.0.0b3/AUTHORS +5 -0
  3. fa4-4.0.0b3/LICENSE +29 -0
  4. fa4-4.0.0b3/MANIFEST.in +5 -0
  5. fa4-4.0.0b3/PKG-INFO +57 -0
  6. fa4-4.0.0b3/README.md +26 -0
  7. fa4-4.0.0b3/__init__.py +26 -0
  8. fa4-4.0.0b3/ampere_helpers.py +103 -0
  9. fa4-4.0.0b3/barrier.py +71 -0
  10. fa4-4.0.0b3/benchmark.py +268 -0
  11. fa4-4.0.0b3/blackwell_helpers.py +1089 -0
  12. fa4-4.0.0b3/block_info.py +108 -0
  13. fa4-4.0.0b3/block_sparse_utils.py +1476 -0
  14. fa4-4.0.0b3/block_sparsity.py +440 -0
  15. fa4-4.0.0b3/cache_utils.py +307 -0
  16. fa4-4.0.0b3/compute_block_sparsity.py +378 -0
  17. fa4-4.0.0b3/copy_utils.py +372 -0
  18. fa4-4.0.0b3/cute_dsl_ptxas.py +151 -0
  19. fa4-4.0.0b3/cute_dsl_utils.py +167 -0
  20. fa4-4.0.0b3/dense_gemm_persistent.py +2190 -0
  21. fa4-4.0.0b3/fa4.egg-info/SOURCES.txt +80 -0
  22. fa4-4.0.0b3/fast_math.py +21 -0
  23. fa4-4.0.0b3/flash_bwd.py +1264 -0
  24. fa4-4.0.0b3/flash_bwd_postprocess.py +585 -0
  25. fa4-4.0.0b3/flash_bwd_preprocess.py +361 -0
  26. fa4-4.0.0b3/flash_bwd_sm100.py +3974 -0
  27. fa4-4.0.0b3/flash_bwd_sm90.py +1591 -0
  28. fa4-4.0.0b3/flash_fwd.py +2426 -0
  29. fa4-4.0.0b3/flash_fwd_combine.py +692 -0
  30. fa4-4.0.0b3/flash_fwd_epitile.py +2467 -0
  31. fa4-4.0.0b3/flash_fwd_sm100.py +2842 -0
  32. fa4-4.0.0b3/flash_fwd_sm100_nopipeline.py +1833 -0
  33. fa4-4.0.0b3/flash_launch.py +235 -0
  34. fa4-4.0.0b3/interface.py +1855 -0
  35. fa4-4.0.0b3/mask.py +653 -0
  36. fa4-4.0.0b3/mma_sm100_desc.py +296 -0
  37. fa4-4.0.0b3/named_barrier.py +32 -0
  38. fa4-4.0.0b3/pack_gqa.py +165 -0
  39. fa4-4.0.0b3/paged_kv.py +214 -0
  40. fa4-4.0.0b3/pipeline.py +440 -0
  41. fa4-4.0.0b3/pyproject.toml +64 -0
  42. fa4-4.0.0b3/sass_patch.py +209 -0
  43. fa4-4.0.0b3/seqlen_info.py +138 -0
  44. fa4-4.0.0b3/setup.cfg +4 -0
  45. fa4-4.0.0b3/softmax.py +592 -0
  46. fa4-4.0.0b3/test_flash_fwd_combine.py +125 -0
  47. fa4-4.0.0b3/testing.py +456 -0
  48. fa4-4.0.0b3/tile_scheduler.py +727 -0
  49. fa4-4.0.0b3/utils.py +698 -0
fa4-4.0.0b3/.flake8 ADDED
@@ -0,0 +1,4 @@
1
+ [flake8]
2
+ max-line-length = 100
3
+ # W503: line break before binary operator
4
+ ignore = E731, E741, F841, W503
fa4-4.0.0b3/AUTHORS ADDED
@@ -0,0 +1,5 @@
1
+ Tri Dao, tri@tridao.me
2
+ Jay Shah
3
+ Ted Zadouri
4
+ Markus Hoehnerbach
5
+ Vijay Thakkar
fa4-4.0.0b3/LICENSE ADDED
@@ -0,0 +1,29 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ * Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ * Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ * Neither the name of the copyright holder nor the names of its
17
+ contributors may be used to endorse or promote products derived from
18
+ this software without specific prior written permission.
19
+
20
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
fa4-4.0.0b3/MANIFEST.in ADDED
@@ -0,0 +1,5 @@
1
+ global-exclude *.egg-info/*
2
+ prune flash_attn_4.egg-info
3
+ prune flash_attn.egg-info
4
+ prune build
5
+ prune dist
fa4-4.0.0b3/PKG-INFO ADDED
@@ -0,0 +1,57 @@
1
+ Metadata-Version: 2.4
2
+ Name: fa4
3
+ Version: 4.0.0b3
4
+ Summary: Flash Attention CUTE (CUDA Template Engine) implementation
5
+ Author: Tri Dao
6
+ License: BSD 3-Clause License
7
+ Project-URL: Homepage, https://github.com/Dao-AILab/flash-attention
8
+ Project-URL: Repository, https://github.com/Dao-AILab/flash-attention
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: License :: OSI Approved :: BSD License
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Requires-Python: >=3.10
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ License-File: AUTHORS
19
+ Requires-Dist: nvidia-cutlass-dsl>=4.4.1
20
+ Requires-Dist: torch
21
+ Requires-Dist: einops
22
+ Requires-Dist: typing_extensions
23
+ Requires-Dist: apache-tvm-ffi<0.2,>=0.1.5
24
+ Requires-Dist: torch-c-dlpack-ext
25
+ Requires-Dist: quack-kernels>=0.2.10
26
+ Requires-Dist: setuptools
27
+ Provides-Extra: dev
28
+ Requires-Dist: pytest; extra == "dev"
29
+ Requires-Dist: ruff; extra == "dev"
30
+ Dynamic: license-file
31
+
32
+ # FlashAttention-4 (CuTeDSL)
33
+
34
+ FlashAttention-4 is a CuTeDSL-based implementation of FlashAttention for Hopper and Blackwell GPUs.
35
+
36
+ ## Installation
37
+
38
+ ```sh
39
+ pip install flash-attn4
40
+ ```
41
+
42
+ ## Usage
43
+
44
+ ```python
45
+ from flash_attn.cute import flash_attn_func, flash_attn_varlen_func
46
+
47
+ out = flash_attn_func(q, k, v, causal=True)
48
+ ```
49
+
50
+ ## Development
51
+
52
+ ```sh
53
+ git clone https://github.com/Dao-AILab/flash-attention.git
54
+ cd flash-attention
55
+ pip install -e "flash_attn/cute[dev]"
56
+ pytest tests/cute/
57
+ ```
fa4-4.0.0b3/README.md ADDED
@@ -0,0 +1,26 @@
1
+ # FlashAttention-4 (CuTeDSL)
2
+
3
+ FlashAttention-4 is a CuTeDSL-based implementation of FlashAttention for Hopper and Blackwell GPUs.
4
+
5
+ ## Installation
6
+
7
+ ```sh
8
+ pip install flash-attn4
9
+ ```
10
+
11
+ ## Usage
12
+
13
+ ```python
14
+ from flash_attn.cute import flash_attn_func, flash_attn_varlen_func
15
+
16
+ out = flash_attn_func(q, k, v, causal=True)
17
+ ```
18
+
19
+ ## Development
20
+
21
+ ```sh
22
+ git clone https://github.com/Dao-AILab/flash-attention.git
23
+ cd flash-attention
24
+ pip install -e "flash_attn/cute[dev]"
25
+ pytest tests/cute/
26
+ ```
fa4-4.0.0b3/__init__.py ADDED
@@ -0,0 +1,26 @@
"""Flash Attention CUTE (CUDA Template Engine) implementation."""

from importlib.metadata import PackageNotFoundError, version


def _detect_version() -> str:
    """Return the installed distribution version, or "0.0.0" when not installed.

    The package may be installed under more than one distribution name
    (``flash-attn4`` or ``fa4``), so each candidate is tried in turn.
    """
    for dist_name in ("flash-attn4", "fa4"):
        try:
            return version(dist_name)
        except PackageNotFoundError:
            continue
    return "0.0.0"


__version__ = _detect_version()

import cutlass.cute as cute

from .interface import (
    flash_attn_func,
    flash_attn_varlen_func,
)

# Relative import, consistent with ``.interface`` above, so the patch always comes
# from this package rather than from whichever ``flash_attn.cute`` is importable.
from .cute_dsl_utils import cute_compile_patched

# Patch cute.compile to optionally dump SASS
cute.compile = cute_compile_patched


__all__ = [
    "flash_attn_func",
    "flash_attn_varlen_func",
]
fa4-4.0.0b3/ampere_helpers.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) 2025, Tri Dao.
2
+ from typing import Type, Callable, Optional
3
+
4
+ import cutlass
5
+ import cutlass.cute as cute
6
+
7
+
8
def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout:
    """Build a swizzled shared-memory layout atom for a tile whose contiguous dim is ``k_dim``.

    The atom's row chunk is the widest of 128 / 64 / 32 bytes that evenly divides one
    ``k_dim`` row, falling back to 16 bytes. The chunk is then converted to elements.
    The swizzle parameters are chosen from that element count and from the element size.

    Args:
        dtype: Element type; only ``dtype.width`` (in bits) is read.
        k_dim: Length of the contiguous (K) dimension, in elements.

    Returns:
        ``swizzle o (num_rows, smem_k_block_size)`` with the K mode innermost
        (``order=(1, 0)``). ``num_rows`` is 8 when ``k_dim % 32 == 0``, else 16.
    """
    elem_bytes = cutlass.const_expr(dtype.width // 8)
    row_bytes = cutlass.const_expr(k_dim * elem_bytes)
    # Widest candidate chunk (in bytes) that tiles a full row exactly.
    block_bytes = 16
    for candidate_bytes in (128, 64, 32):
        if row_bytes % candidate_bytes == 0:
            block_bytes = candidate_bytes
            break
    smem_k_block_size = cutlass.const_expr(block_bytes) // elem_bytes
    # NOTE(review): swizzle_bits is keyed on the chunk width in *elements*, not bytes, so
    # the same 128-byte chunk yields different swizzle widths for 1-, 2- and 4-byte types.
    # Confirm this is intended for non-16-bit dtypes.
    swizzle_bits = {128: 4, 64: 3, 32: 2}.get(smem_k_block_size, 1)
    swizzle_base = {4: 2, 2: 3}.get(elem_bytes, 4)
    num_rows = 8 if cutlass.const_expr(k_dim % 32 == 0) else 16
    swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base)
    inner_layout = cute.make_ordered_layout((num_rows, smem_k_block_size), order=(1, 0))
    return cute.make_composed_layout(swizzle, 0, inner_layout)
32
+
33
+
34
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsA: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_A: cute.TiledCopy,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
    A_in_regs: cutlass.Constexpr[bool] = False,
    B_in_regs: cutlass.Constexpr[bool] = False,
    swap_AB: cutlass.Constexpr[bool] = False,
) -> None:
    """Accumulate into ``acc`` over all K-blocks, overlapping operand copies with MMAs.

    The copy for K-block ``k + 1`` (``tCs*`` -> ``tCr*`` via the given tiled copies) is
    issued before the ``cute.gemm`` on K-block ``k``, which uses ``acc`` as both C and D.

    Args:
        tiled_mma: MMA used for every K-block.
        acc: Accumulator fragment, read and written in place.
        tCrA, tCrB: Destination fragments for A / B, indexed ``[None, None, k]``.
        tCsA, tCsB: Source partitions for A / B; ``tCsA.shape[2]`` gives the K-block count.
        smem_thr_copy_A, smem_thr_copy_B: Tiled copies whose ``retile`` views of
            ``tCrA`` / ``tCrB`` receive the copies.
        hook_fn: Optional callable invoked once, right after the first K-block's MMA.
        A_in_regs, B_in_regs: When set, that operand is not copied at all.
        swap_AB: When set, re-dispatch with A and B (and their ``*_in_regs`` flags) swapped.
    """
    if cutlass.const_expr(swap_AB):
        gemm(
            tiled_mma,
            acc,
            tCrB,
            tCrA,
            tCsB,
            tCsA,
            smem_thr_copy_B,
            smem_thr_copy_A,
            hook_fn,
            A_in_regs=B_in_regs,
            B_in_regs=A_in_regs,
            swap_AB=False,
        )
    else:
        tCrA_copy_view = smem_thr_copy_A.retile(tCrA)
        tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
        # Compile-time K-block count (range_constexpr below requires a static value).
        num_k_blocks = cute.size(tCsA.shape[2])
        # Prologue: stage K-block 0.
        if cutlass.const_expr(not A_in_regs):
            cute.copy(smem_thr_copy_A, tCsA[None, None, 0], tCrA_copy_view[None, None, 0])
        if cutlass.const_expr(not B_in_regs):
            cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
        for k in cutlass.range_constexpr(num_k_blocks):
            # Prefetch K-block k + 1 before issuing the MMA on K-block k. The guard is a
            # compile-time condition, written as const_expr to match gemm_rs.
            if cutlass.const_expr(k < num_k_blocks - 1):
                if cutlass.const_expr(not A_in_regs):
                    cute.copy(
                        smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1]
                    )
                if cutlass.const_expr(not B_in_regs):
                    cute.copy(
                        smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]
                    )
            cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
            if cutlass.const_expr(k == 0 and hook_fn is not None):
                hook_fn()
84
+
85
+
86
@cute.jit
def gemm_rs(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
) -> None:
    """Accumulate into ``acc`` with A already in ``tCrA`` and B copied K-block by K-block.

    Only B is copied (``tCsB`` -> retiled view of ``tCrB``). The copy for K-block
    ``k + 1`` is issued before the MMA on K-block ``k``. ``hook_fn``, if given, runs
    once after the first MMA.
    """
    b_reg_view = smem_thr_copy_B.retile(tCrB)
    # Static K-block count taken from the A fragment.
    k_block_count = cute.size(tCrA.shape[2])
    cute.copy(smem_thr_copy_B, tCsB[None, None, 0], b_reg_view[None, None, 0])
    for kb in cutlass.range_constexpr(k_block_count):
        next_kb = kb + 1
        if cutlass.const_expr(next_kb < k_block_count):
            cute.copy(smem_thr_copy_B, tCsB[None, None, next_kb], b_reg_view[None, None, next_kb])
        cute.gemm(tiled_mma, acc, tCrA[None, None, kb], tCrB[None, None, kb], acc)
        if cutlass.const_expr(kb == 0 and hook_fn is not None):
            hook_fn()
fa4-4.0.0b3/barrier.py ADDED
@@ -0,0 +1,71 @@
1
+ import cutlass
2
+ import cutlass.cute as cute
3
+ from cutlass import Int32
4
+ from cutlass.cutlass_dsl import T, dsl_user_op
5
+ from cutlass._mlir.dialects import llvm
6
+
7
+
8
@dsl_user_op
def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
    """Load the 32-bit word at ``lock_ptr`` with GPU-scope acquire ordering.

    Emitted as inline PTX ``ld.global.acquire.gpu.b32``. The asm is marked as having
    side effects, so the compiler does not treat it as a pure, removable operation.
    """
    address = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    ptx = "ld.global.acquire.gpu.b32 $0, [$1];"
    constraints = "=r,l"  # one 32-bit register output, one 64-bit address input
    loaded = llvm.inline_asm(
        T.i32(),
        [address],
        ptx,
        constraints,
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(loaded)
21
+
22
+
23
def _red_add_s32_gpu_global(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], sem: str, *, loc=None, ip=None
) -> None:
    """Emit ``red.<sem>.gpu.global.add.s32 [lock_ptr], val`` as inline PTX.

    Shared by ``red_relaxed`` and ``red_release``, which differ only in ``sem``.
    The asm has no result and is marked as having side effects.
    """
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    llvm.inline_asm(
        None,
        [lock_ptr_i64, Int32(val).ir_value(loc=loc, ip=ip)],
        f"red.{sem}.gpu.global.add.s32 [$0], $1;",
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def red_relaxed(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    """Atomically add ``val`` to the int32 at ``lock_ptr`` (PTX ``red.relaxed.gpu``)."""
    _red_add_s32_gpu_global(lock_ptr, val, "relaxed", loc=loc, ip=ip)


@dsl_user_op
def red_release(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    """Atomically add ``val`` to the int32 at ``lock_ptr`` (PTX ``red.release.gpu``)."""
    _red_add_s32_gpu_global(lock_ptr, val, "release", loc=loc, ip=ip)
53
+
54
+
55
@cute.jit
def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None:
    """On thread 0 only, spin until the int32 flag at ``lock_ptr + flag_offset`` equals ``val``.

    Each poll is an ``ld_acquire``. Other threads fall straight through; this function
    issues no barrier of its own.

    NOTE(review): the poll variable starts at 0 and is compared before the first load.
    When ``val == 0``, no load is issued at all and the call returns without acquire
    ordering. This is presumably fine for a zero-initialised, increment-only flag; confirm.
    """
    target_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        observed = Int32(0)
        while observed != val:
            observed = ld_acquire(target_ptr)
62
+
63
+
64
@cute.jit
def arrive_inc(
    lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32]
) -> None:
    """On thread 0 only, add ``val`` to the int32 flag at ``lock_ptr + flag_offset``.

    Uses ``red_release`` (PTX ``red.release.gpu``). Waiters in this module poll with
    acquire loads (see ``wait_eq``).
    """
    target_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        # red_relaxed(target_ptr, val) would be the weaker-ordering alternative.
        red_release(target_ptr, val)
fa4-4.0.0b3/benchmark.py ADDED
@@ -0,0 +1,268 @@
1
+ # Copyright (c) 2023, Tri Dao.
2
+ """Useful functions for writing test code."""
3
+
4
+ import torch
5
+ import torch.utils.benchmark as benchmark
6
+
7
+
8
def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Time ``fn(*inputs, **kwinputs)`` with ``torch.utils.benchmark.Timer``.

    The call runs inside ``torch.autocast(device_type="cuda")``, which is active only
    when ``amp`` is true. The result of ``fn`` is discarded.

    Returns:
        ``(timer, measurement)``, where ``measurement = timer.timeit(repeats)``.
    """
    if verbose:
        print(desc, "- Forward pass")

    def run_under_autocast(*args, **kwargs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*args, **kwargs)

    timer_globals = {"fn_amp": run_under_autocast, "inputs": inputs, "kwinputs": kwinputs}
    timer = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals=timer_globals,
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
28
+
29
+
30
def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the backward pass of an arbitrary function.

    ``fn`` is run once (under autocast when ``amp``) to build the graph. Then
    ``y.backward(grad, retain_graph=True)`` is timed ``repeats`` times. If ``fn``
    returns a tuple, including tuple subclasses such as ``torch.return_types.*``
    and namedtuples, its first element is used as ``y``.

    Args:
        grad: Upstream gradient; a random tensor like ``y`` when ``None``.

    Returns:
        ``(timer, measurement)``.

    Raises:
        RuntimeError: If an explicit ``grad`` does not match ``y``'s shape.
    """
    if verbose:
        print(desc, "- Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        # isinstance (not ``type(y) is tuple``) so tuple subclasses are unwrapped too;
        # otherwise randn_like/backward below would fail on them.
        if isinstance(y, tuple):
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    elif grad.shape != y.shape:
        raise RuntimeError("Grad shape does not match output shape")

    def f(*inputs, y, grad):
        # Set .grad to None to avoid extra operation of gradient accumulation
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
70
+
71
+
72
def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward+backward pass of an arbitrary function.

    Each timed iteration clears ``.grad`` on tensor inputs, re-runs ``fn`` (under
    autocast when ``amp``) and back-propagates ``grad``. Tuple outputs, including
    tuple subclasses such as ``torch.return_types.*``, are reduced to their first
    element.

    Args:
        grad: Upstream gradient; a random tensor like the output when ``None``.

    Returns:
        ``(timer, measurement)``.

    Raises:
        RuntimeError: If an explicit ``grad`` does not match the output shape.
    """
    if verbose:
        print(desc, "- Forward + Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if isinstance(y, tuple):
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    elif grad.shape != y.shape:
        raise RuntimeError("Grad shape does not match output shape")

    def f(grad, *inputs, **kwinputs):
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if isinstance(y, tuple):
                y = y[0]
        y.backward(grad, retain_graph=True)

    # ``fn`` is reached through f's closure, so the timer namespace does not need it.
    t = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={"f": f, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
115
+
116
+
117
def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Benchmark the forward pass and the backward pass of ``fn`` separately.

    Returns:
        ``(benchmark_forward(...), benchmark_backward(...))``; each item is a
        ``(timer, measurement)`` pair. Only the backward benchmark receives ``grad``.
    """
    shared = dict(repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype)
    forward_result = benchmark_forward(fn, *inputs, **shared, **kwinputs)
    backward_result = benchmark_backward(fn, *inputs, grad=grad, **shared, **kwinputs)
    return forward_result, backward_result
152
+
153
+
154
def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Benchmark forward, backward, and combined forward+backward passes of ``fn``.

    Returns:
        ``(forward, backward, combined)``; each item is the ``(timer, measurement)``
        pair from ``benchmark_forward``, ``benchmark_backward`` and
        ``benchmark_combined`` respectively.
    """
    shared = dict(repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype)
    forward_result = benchmark_forward(fn, *inputs, **shared, **kwinputs)
    backward_result = benchmark_backward(fn, *inputs, grad=grad, **shared, **kwinputs)
    combined_result = benchmark_combined(fn, *inputs, grad=grad, **shared, **kwinputs)
    return forward_result, backward_result, combined_result
200
+
201
+
202
def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Wrap benchmark functions in Pytorch profiler to see CUDA information.

    Runs 30 warm-up iterations of ``fn(*inputs, **kwinputs)``, then profiles one more
    under ``torch.profiler.profile``. CUDA activity is always recorded; CPU activity is
    recorded only when ``cpu`` is true. With ``backward=True`` every iteration first
    clears ``.grad`` on tensor inputs and then back-propagates one fixed random gradient.
    Tuple outputs, including tuple subclasses such as ``torch.return_types.*``, are
    reduced to their first element.

    Args:
        trace_filename: If given, the Chrome trace is exported to this path.
        verbose: Print the ``key_averages`` table (top 50 rows).
    """

    def run_forward():
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if isinstance(out, tuple):
                out = out[0]
        return out

    g = torch.randn_like(run_forward()) if backward else None

    def run_iteration():
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        out = run_forward()
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)

    for _ in range(30):  # Warm up
        run_iteration()
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        run_iteration()
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)
256
+
257
+
258
def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    """Return the peak CUDA memory allocated across one call of ``fn``, in GiB (2**30 bytes).

    The peak counter is reset just before the call. The reported peak therefore also
    includes memory already allocated at that point (e.g. the inputs). The CUDA cache
    is emptied before and after.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    # The old divisor, 2**20 * 1000 (~1.049e9), was neither GB (1e9) nor GiB (2**30);
    # use binary gigabytes consistently and label them as such.
    mem = torch.cuda.max_memory_allocated() / 2**30
    if verbose:
        print(f"{desc} max memory: {mem}GiB")
    torch.cuda.empty_cache()
    return mem