fa4 4.0.0b3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fa4-4.0.0b3/.flake8 +4 -0
- fa4-4.0.0b3/AUTHORS +5 -0
- fa4-4.0.0b3/LICENSE +29 -0
- fa4-4.0.0b3/MANIFEST.in +5 -0
- fa4-4.0.0b3/PKG-INFO +57 -0
- fa4-4.0.0b3/README.md +26 -0
- fa4-4.0.0b3/__init__.py +26 -0
- fa4-4.0.0b3/ampere_helpers.py +103 -0
- fa4-4.0.0b3/barrier.py +71 -0
- fa4-4.0.0b3/benchmark.py +268 -0
- fa4-4.0.0b3/blackwell_helpers.py +1089 -0
- fa4-4.0.0b3/block_info.py +108 -0
- fa4-4.0.0b3/block_sparse_utils.py +1476 -0
- fa4-4.0.0b3/block_sparsity.py +440 -0
- fa4-4.0.0b3/cache_utils.py +307 -0
- fa4-4.0.0b3/compute_block_sparsity.py +378 -0
- fa4-4.0.0b3/copy_utils.py +372 -0
- fa4-4.0.0b3/cute_dsl_ptxas.py +151 -0
- fa4-4.0.0b3/cute_dsl_utils.py +167 -0
- fa4-4.0.0b3/dense_gemm_persistent.py +2190 -0
- fa4-4.0.0b3/fa4.egg-info/SOURCES.txt +80 -0
- fa4-4.0.0b3/fast_math.py +21 -0
- fa4-4.0.0b3/flash_bwd.py +1264 -0
- fa4-4.0.0b3/flash_bwd_postprocess.py +585 -0
- fa4-4.0.0b3/flash_bwd_preprocess.py +361 -0
- fa4-4.0.0b3/flash_bwd_sm100.py +3974 -0
- fa4-4.0.0b3/flash_bwd_sm90.py +1591 -0
- fa4-4.0.0b3/flash_fwd.py +2426 -0
- fa4-4.0.0b3/flash_fwd_combine.py +692 -0
- fa4-4.0.0b3/flash_fwd_epitile.py +2467 -0
- fa4-4.0.0b3/flash_fwd_sm100.py +2842 -0
- fa4-4.0.0b3/flash_fwd_sm100_nopipeline.py +1833 -0
- fa4-4.0.0b3/flash_launch.py +235 -0
- fa4-4.0.0b3/interface.py +1855 -0
- fa4-4.0.0b3/mask.py +653 -0
- fa4-4.0.0b3/mma_sm100_desc.py +296 -0
- fa4-4.0.0b3/named_barrier.py +32 -0
- fa4-4.0.0b3/pack_gqa.py +165 -0
- fa4-4.0.0b3/paged_kv.py +214 -0
- fa4-4.0.0b3/pipeline.py +440 -0
- fa4-4.0.0b3/pyproject.toml +64 -0
- fa4-4.0.0b3/sass_patch.py +209 -0
- fa4-4.0.0b3/seqlen_info.py +138 -0
- fa4-4.0.0b3/setup.cfg +4 -0
- fa4-4.0.0b3/softmax.py +592 -0
- fa4-4.0.0b3/test_flash_fwd_combine.py +125 -0
- fa4-4.0.0b3/testing.py +456 -0
- fa4-4.0.0b3/tile_scheduler.py +727 -0
- fa4-4.0.0b3/utils.py +698 -0
fa4-4.0.0b3/.flake8
ADDED
fa4-4.0.0b3/AUTHORS
ADDED
fa4-4.0.0b3/LICENSE
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
BSD 3-Clause License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
|
|
4
|
+
All rights reserved.
|
|
5
|
+
|
|
6
|
+
Redistribution and use in source and binary forms, with or without
|
|
7
|
+
modification, are permitted provided that the following conditions are met:
|
|
8
|
+
|
|
9
|
+
* Redistributions of source code must retain the above copyright notice, this
|
|
10
|
+
list of conditions and the following disclaimer.
|
|
11
|
+
|
|
12
|
+
* Redistributions in binary form must reproduce the above copyright notice,
|
|
13
|
+
this list of conditions and the following disclaimer in the documentation
|
|
14
|
+
and/or other materials provided with the distribution.
|
|
15
|
+
|
|
16
|
+
* Neither the name of the copyright holder nor the names of its
|
|
17
|
+
contributors may be used to endorse or promote products derived from
|
|
18
|
+
this software without specific prior written permission.
|
|
19
|
+
|
|
20
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
21
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
22
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
23
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
24
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
25
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
26
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
27
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
28
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
29
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
fa4-4.0.0b3/MANIFEST.in
ADDED
fa4-4.0.0b3/PKG-INFO
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: fa4
|
|
3
|
+
Version: 4.0.0b3
|
|
4
|
+
Summary: Flash Attention CUTE (CUDA Template Engine) implementation
|
|
5
|
+
Author: Tri Dao
|
|
6
|
+
License: BSD 3-Clause License
|
|
7
|
+
Project-URL: Homepage, https://github.com/Dao-AILab/flash-attention
|
|
8
|
+
Project-URL: Repository, https://github.com/Dao-AILab/flash-attention
|
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
|
10
|
+
Classifier: License :: OSI Approved :: BSD License
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Requires-Python: >=3.10
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
License-File: LICENSE
|
|
18
|
+
License-File: AUTHORS
|
|
19
|
+
Requires-Dist: nvidia-cutlass-dsl>=4.4.1
|
|
20
|
+
Requires-Dist: torch
|
|
21
|
+
Requires-Dist: einops
|
|
22
|
+
Requires-Dist: typing_extensions
|
|
23
|
+
Requires-Dist: apache-tvm-ffi<0.2,>=0.1.5
|
|
24
|
+
Requires-Dist: torch-c-dlpack-ext
|
|
25
|
+
Requires-Dist: quack-kernels>=0.2.10
|
|
26
|
+
Requires-Dist: setuptools
|
|
27
|
+
Provides-Extra: dev
|
|
28
|
+
Requires-Dist: pytest; extra == "dev"
|
|
29
|
+
Requires-Dist: ruff; extra == "dev"
|
|
30
|
+
Dynamic: license-file
|
|
31
|
+
|
|
32
|
+
# FlashAttention-4 (CuTeDSL)
|
|
33
|
+
|
|
34
|
+
FlashAttention-4 is a CuTeDSL-based implementation of FlashAttention for Hopper and Blackwell GPUs.
|
|
35
|
+
|
|
36
|
+
## Installation
|
|
37
|
+
|
|
38
|
+
```sh
|
|
39
|
+
pip install flash-attn4
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
## Usage
|
|
43
|
+
|
|
44
|
+
```python
|
|
45
|
+
from flash_attn.cute import flash_attn_func, flash_attn_varlen_func
|
|
46
|
+
|
|
47
|
+
out = flash_attn_func(q, k, v, causal=True)
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
## Development
|
|
51
|
+
|
|
52
|
+
```sh
|
|
53
|
+
git clone https://github.com/Dao-AILab/flash-attention.git
|
|
54
|
+
cd flash-attention
|
|
55
|
+
pip install -e "flash_attn/cute[dev]"
|
|
56
|
+
pytest tests/cute/
|
|
57
|
+
```
|
fa4-4.0.0b3/README.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# FlashAttention-4 (CuTeDSL)
|
|
2
|
+
|
|
3
|
+
FlashAttention-4 is a CuTeDSL-based implementation of FlashAttention for Hopper and Blackwell GPUs.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
```sh
|
|
8
|
+
pip install flash-attn4
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## Usage
|
|
12
|
+
|
|
13
|
+
```python
|
|
14
|
+
from flash_attn.cute import flash_attn_func, flash_attn_varlen_func
|
|
15
|
+
|
|
16
|
+
out = flash_attn_func(q, k, v, causal=True)
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
## Development
|
|
20
|
+
|
|
21
|
+
```sh
|
|
22
|
+
git clone https://github.com/Dao-AILab/flash-attention.git
|
|
23
|
+
cd flash-attention
|
|
24
|
+
pip install -e "flash_attn/cute[dev]"
|
|
25
|
+
pytest tests/cute/
|
|
26
|
+
```
|
fa4-4.0.0b3/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""Flash Attention CUTE (CUDA Template Engine) implementation."""
|
|
2
|
+
|
|
3
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
4
|
+
|
|
5
|
+
# Resolve the installed distribution's version. The README installs this project as
# "flash-attn4", while this distribution's metadata declares ``Name: fa4``; try both so the
# real version is reported under either name, and fall back to "0.0.0" when neither is
# installed (e.g. when running from a source checkout).
try:
    __version__ = version("flash-attn4")
except PackageNotFoundError:
    try:
        __version__ = version("fa4")
    except PackageNotFoundError:
        __version__ = "0.0.0"
|
|
9
|
+
|
|
10
|
+
import cutlass.cute as cute
|
|
11
|
+
|
|
12
|
+
from .interface import (
|
|
13
|
+
flash_attn_func,
|
|
14
|
+
flash_attn_varlen_func,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from flash_attn.cute.cute_dsl_utils import cute_compile_patched
|
|
18
|
+
|
|
19
|
+
# Patch cute.compile to optionally dump SASS.
# NOTE(review): this rebinds the attribute on the shared `cutlass.cute` module object, so any
# code that looks up `cute.compile` after this package is imported gets the patched version,
# not only this package's kernels.
cute.compile = cute_compile_patched


# Public API re-exported from `.interface`.
__all__ = [
    "flash_attn_func",
    "flash_attn_varlen_func",
]
|
fa4-4.0.0b3/ampere_helpers.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
# Copyright (c) 2025, Tri Dao.
|
|
2
|
+
from typing import Type, Callable, Optional
|
|
3
|
+
|
|
4
|
+
import cutlass
|
|
5
|
+
import cutlass.cute as cute
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_smem_layout_atom(dtype: Type[cutlass.Numeric], k_dim: int) -> cute.ComposedLayout:
    """Build a swizzled, row-major layout atom for rows of ``k_dim`` elements of ``dtype``.

    Column count (in elements): the widest of 128/64/32 bytes that evenly divides one row of
    ``k_dim`` elements (16 bytes otherwise), divided by the element size in bytes.
    Row count: 8 when ``k_dim`` is a multiple of 32, else 16.
    Swizzle: bit count follows the column count (128 -> 4, 64 -> 3, 32 -> 2, else 1); base and
    shift both follow the element size in bytes (4 -> 2, 2 -> 3, else 4).
    """
    elem_bytes = cutlass.const_expr(dtype.width // 8)
    row_bytes = cutlass.const_expr(k_dim * elem_bytes)
    # Widest byte chunk that tiles a row exactly; 16 bytes is the unconditional fallback.
    chunk_bytes = next((c for c in (128, 64, 32) if row_bytes % c == 0), 16)
    cols = cutlass.const_expr(chunk_bytes) // elem_bytes
    swizzle_bits = {128: 4, 64: 3, 32: 2}.get(cols, 1)
    swizzle_base = {4: 2, 2: 3}.get(elem_bytes, 4)
    swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, swizzle_base)
    rows = 8 if cutlass.const_expr(k_dim % 32 == 0) else 16
    atom_layout = cute.make_ordered_layout((rows, cols), order=(1, 0))
    return cute.make_composed_layout(swizzle, 0, atom_layout)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@cute.jit
def gemm(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsA: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_A: cute.TiledCopy,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
    A_in_regs: cutlass.Constexpr[bool] = False,
    B_in_regs: cutlass.Constexpr[bool] = False,
    swap_AB: cutlass.Constexpr[bool] = False,
) -> None:
    """Accumulate A x B into ``acc``, one slice of mode 2 (the k-blocks) at a time.

    For each index ``k`` over ``cute.size(tCsA.shape[2])`` it calls
    ``cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)``.
    Slice 0 of each operand is copied from ``tCs*`` into the retiled ``tCr*`` view before the
    loop. Inside the loop, slice ``k + 1`` is copied before the MMA on slice ``k``, so the copy
    for the next k-block is issued ahead of the current MMA.

    Args:
        tiled_mma: Tiled MMA passed through to ``cute.gemm``.
        acc: Accumulator; read and written by every ``cute.gemm`` call.
        tCrA, tCrB: Operand fragments consumed by the MMA (presumably registers, per the
            ``r`` naming — confirm).
        tCsA, tCsB: Source partitions the fragments are filled from (presumably shared
            memory, per the ``s`` / ``smem_`` naming — confirm). ``tCsA`` also sets the loop
            trip count, even when ``A_in_regs`` is True.
        smem_thr_copy_A, smem_thr_copy_B: Copies used to retile and fill ``tCrA`` / ``tCrB``.
        hook_fn: Optional callback, invoked once right after the MMA for ``k == 0``.
        A_in_regs, B_in_regs: When True, skip every copy for that operand; its fragment is
            used as already populated by the caller.
        swap_AB: When True, re-dispatch to ``gemm`` with the A and B roles (and their
            ``*_in_regs`` flags) exchanged.
    """
    if cutlass.const_expr(swap_AB):
        # Same computation with operand roles exchanged; recursion stops because swap_AB=False.
        gemm(
            tiled_mma,
            acc,
            tCrB,
            tCrA,
            tCsB,
            tCsA,
            smem_thr_copy_B,
            smem_thr_copy_A,
            hook_fn,
            A_in_regs=B_in_regs,
            B_in_regs=A_in_regs,
            swap_AB=False,
        )
    else:
        tCrA_copy_view = smem_thr_copy_A.retile(tCrA)
        tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
        # Prologue: fill k-block 0 before the main loop.
        if cutlass.const_expr(not A_in_regs):
            cute.copy(smem_thr_copy_A, tCsA[None, None, 0], tCrA_copy_view[None, None, 0])
        if cutlass.const_expr(not B_in_regs):
            cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
        for k in cutlass.range_constexpr(cute.size(tCsA.shape[2])):
            # Issue the copy for k-block k+1 (if any) before the MMA on k-block k.
            # NOTE(review): unlike gemm_rs this test is not wrapped in cutlass.const_expr; with
            # range_constexpr, k is presumably a trace-time Python int, so it should still
            # resolve at trace time — confirm.
            if k < cute.size(tCsA.shape[2]) - 1:
                if cutlass.const_expr(not A_in_regs):
                    cute.copy(
                        smem_thr_copy_A, tCsA[None, None, k + 1], tCrA_copy_view[None, None, k + 1]
                    )
                if cutlass.const_expr(not B_in_regs):
                    cute.copy(
                        smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1]
                    )
            cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
            # The caller's hook runs exactly once, after the first MMA has been issued.
            if cutlass.const_expr(k == 0 and hook_fn is not None):
                hook_fn()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@cute.jit
def gemm_rs(
    tiled_mma: cute.TiledMma,
    acc: cute.Tensor,
    tCrA: cute.Tensor,
    tCrB: cute.Tensor,
    tCsB: cute.Tensor,
    smem_thr_copy_B: cute.TiledCopy,
    hook_fn: Optional[Callable] = None,
) -> None:
    """Accumulate A x B into ``acc`` where only B is copied in; A is used as given.

    Same schedule as ``gemm`` with ``A_in_regs=True``. The loop runs over
    ``cute.size(tCrA.shape[2])`` k-blocks. B's slice 0 is copied before the loop, and slice
    ``k + 1`` is copied before the MMA on slice ``k``.

    Args:
        tiled_mma: Tiled MMA passed through to ``cute.gemm``.
        acc: Accumulator; read and written by every ``cute.gemm`` call.
        tCrA: A operand fragment, consumed directly (never copied here).
        tCrB: B operand fragment, filled from ``tCsB`` through ``smem_thr_copy_B``.
        tCsB: Source partition for B (presumably shared memory, per the naming — confirm).
        smem_thr_copy_B: Copy used to retile and fill ``tCrB``.
        hook_fn: Optional callback, invoked once right after the MMA for ``k == 0``.
    """
    tCrB_copy_view = smem_thr_copy_B.retile(tCrB)
    # Prologue: fill B's k-block 0.
    cute.copy(smem_thr_copy_B, tCsB[None, None, 0], tCrB_copy_view[None, None, 0])
    for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
        # Issue the copy for k-block k+1 (if any) before the MMA on k-block k.
        if cutlass.const_expr(k < cute.size(tCrA.shape[2]) - 1):
            cute.copy(smem_thr_copy_B, tCsB[None, None, k + 1], tCrB_copy_view[None, None, k + 1])
        cute.gemm(tiled_mma, acc, tCrA[None, None, k], tCrB[None, None, k], acc)
        if cutlass.const_expr(k == 0 and hook_fn is not None):
            hook_fn()
|
fa4-4.0.0b3/barrier.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import cutlass
|
|
2
|
+
import cutlass.cute as cute
|
|
3
|
+
from cutlass import Int32
|
|
4
|
+
from cutlass.cutlass_dsl import T, dsl_user_op
|
|
5
|
+
from cutlass._mlir.dialects import llvm
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dsl_user_op
def ld_acquire(lock_ptr: cute.Pointer, *, loc=None, ip=None) -> cutlass.Int32:
    """Load the 32-bit value at ``lock_ptr`` with acquire semantics at GPU scope.

    Emits the inline PTX ``ld.global.acquire.gpu.b32``.

    Args:
        lock_ptr: Pointer to the 32-bit flag. The PTX uses the ``.global`` state space, so it
            presumably must point to global memory — confirm at call sites.
        loc, ip: Keyword-only MLIR location / insertion point (presumably supplied by the
            ``dsl_user_op`` decorator).

    Returns:
        The loaded value as ``cutlass.Int32``.
    """
    # Pointer as a 64-bit integer, to feed the "l" (64-bit) operand of the asm below.
    lock_ptr_i64 = lock_ptr.toint(loc=loc, ip=ip).ir_value()
    state = llvm.inline_asm(
        T.i32(),
        [lock_ptr_i64],
        "ld.global.acquire.gpu.b32 $0, [$1];",
        "=r,l",
        # Side-effecting, presumably so the load is not hoisted or merged when polled in a loop.
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(state)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _red_add_s32(lock_ptr, val, asm_str, loc, ip):
    """Emit the given PTX ``red ... add.s32`` instruction adding ``val`` at ``lock_ptr``."""
    operands = [
        lock_ptr.toint(loc=loc, ip=ip).ir_value(),
        Int32(val).ir_value(loc=loc, ip=ip),
    ]
    llvm.inline_asm(
        None,
        operands,
        asm_str,
        "l,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


@dsl_user_op
def red_relaxed(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    """Add ``val`` to the 32-bit value at ``lock_ptr`` with a relaxed, GPU-scope reduction.

    No value is returned; the PTX is ``red.relaxed.gpu.global.add.s32``.
    """
    _red_add_s32(lock_ptr, val, "red.relaxed.gpu.global.add.s32 [$0], $1;", loc, ip)


@dsl_user_op
def red_release(
    lock_ptr: cute.Pointer, val: cutlass.Constexpr[Int32], *, loc=None, ip=None
) -> None:
    """Add ``val`` to the 32-bit value at ``lock_ptr`` with a release, GPU-scope reduction.

    No value is returned; the PTX is ``red.release.gpu.global.add.s32``.
    """
    _red_add_s32(lock_ptr, val, "red.release.gpu.global.add.s32 [$0], $1;", loc, ip)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@cute.jit
def wait_eq(lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: Int32) -> None:
    """Spin until the flag at ``lock_ptr + flag_offset`` equals ``val``.

    Only the thread with ``thread_idx == 0`` polls, using ``ld_acquire``; every other thread
    returns immediately. NOTE(review): callers presumably follow this with a barrier so the
    remaining threads also wait — confirm at call sites.

    Because the local copy starts at 0, no load is issued at all when ``val == 0``, and the
    function returns without performing an acquire.
    """
    flag_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        read_val = Int32(0)
        while read_val != val:
            read_val = ld_acquire(flag_ptr)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@cute.jit
def arrive_inc(
    lock_ptr: cute.Pointer, thread_idx: int | Int32, flag_offset: int, val: cutlass.Constexpr[Int32]
) -> None:
    """Add ``val`` to the flag at ``lock_ptr + flag_offset`` with a release reduction.

    Only the thread with ``thread_idx == 0`` performs the update (via ``red_release``); other
    threads do nothing. NOTE(review): presumably pairs with ``wait_eq`` on the waiting
    side — confirm at call sites.
    """
    flag_ptr = lock_ptr + flag_offset
    if thread_idx == 0:
        red_release(flag_ptr, val)
        # Relaxed alternative, kept for reference:
        # red_relaxed(flag_ptr, val)
|
fa4-4.0.0b3/benchmark.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
# Copyright (c) 2023, Tri Dao.
|
|
2
|
+
"""Useful functions for writing test code."""
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.utils.benchmark as benchmark
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Time ``fn(*inputs, **kwinputs)`` with ``torch.utils.benchmark``.

    Each timed call runs inside ``torch.autocast(device_type="cuda", dtype=amp_dtype,
    enabled=amp)``.

    Returns:
        ``(timer, measurement)``: the ``benchmark.Timer`` and the result of
        ``timer.timeit(repeats)``.
    """
    if verbose:
        print(desc, "- Forward pass")

    def run_forward(*args, **kwargs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*args, **kwargs)

    timer = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals=dict(fn_amp=run_forward, inputs=inputs, kwinputs=kwinputs),
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Time only the backward pass of ``fn`` with ``torch.utils.benchmark``.

    ``fn`` is run once, under optional autocast, to build the graph. If it returns a tuple,
    only the first element is differentiated. Every timed run clears ``.grad`` on tensor inputs
    and calls ``backward(grad, retain_graph=True)`` on that same retained graph.

    Args:
        grad: Upstream gradient; drawn with ``torch.randn_like`` when omitted.

    Returns:
        ``(timer, measurement)``.

    Raises:
        RuntimeError: if ``grad`` is given and its shape differs from the output's.
    """
    if verbose:
        print(desc, "- Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        out = fn(*inputs, **kwinputs)
        if type(out) is tuple:
            out = out[0]
    if grad is None:
        grad = torch.randn_like(out)
    elif grad.shape != out.shape:
        raise RuntimeError("Grad shape does not match output shape")

    def backward_only(*tensors, y, grad):
        # Reset .grad so each run measures backward alone, without gradient accumulation.
        for tensor in tensors:
            if isinstance(tensor, torch.Tensor):
                tensor.grad = None
        y.backward(grad, retain_graph=True)

    timer = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": backward_only, "inputs": inputs, "y": out, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Time one forward plus one backward of ``fn`` per measured run.

    A probe forward pass, outside the timer, fixes the output shape so that a random ``grad``
    can be drawn when none is given; a user-supplied ``grad`` must match that shape. Each timed
    run clears ``.grad`` on tensor inputs and reruns ``fn`` under optional autocast. It then
    backpropagates ``grad`` through the first output (when ``fn`` returns a tuple).

    Returns:
        ``(timer, measurement)``.

    Raises:
        RuntimeError: if ``grad`` is given and its shape differs from the output's.
    """
    if verbose:
        print(desc, "- Forward + Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        probe = fn(*inputs, **kwinputs)
        if type(probe) is tuple:
            probe = probe[0]
    if grad is None:
        grad = torch.randn_like(probe)
    elif grad.shape != probe.shape:
        raise RuntimeError("Grad shape does not match output shape")

    def forward_backward(grad, *inputs, **kwinputs):
        for tensor in inputs:
            if isinstance(tensor, torch.Tensor):
                tensor.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        # Backward runs outside the autocast region.
        out.backward(grad, retain_graph=True)

    timer = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={
            "f": forward_backward,
            "fn": fn,
            "inputs": inputs,
            "grad": grad,
            "kwinputs": kwinputs,
        },
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    if verbose:
        print(measurement)
    return timer, measurement
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Benchmark the forward pass and, separately, the backward pass of ``fn``.

    Returns:
        A 2-tuple of ``(timer, measurement)`` pairs, from ``benchmark_forward`` and then
        ``benchmark_backward``.
    """
    shared = dict(repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype)
    forward = benchmark_forward(fn, *inputs, **shared, **kwinputs)
    backward = benchmark_backward(fn, *inputs, grad=grad, **shared, **kwinputs)
    return forward, backward
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Benchmark forward, backward, and combined forward+backward passes of ``fn``.

    Returns:
        A 3-tuple of ``(timer, measurement)`` pairs, from ``benchmark_forward``,
        ``benchmark_backward`` and ``benchmark_combined``, in that order.
    """
    shared = dict(repeats=repeats, desc=desc, verbose=verbose, amp=amp, amp_dtype=amp_dtype)
    forward = benchmark_forward(fn, *inputs, **shared, **kwinputs)
    backward = benchmark_backward(fn, *inputs, grad=grad, **shared, **kwinputs)
    combined = benchmark_combined(fn, *inputs, grad=grad, **shared, **kwinputs)
    return forward, backward, combined
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Profile one step of ``fn`` (optionally including backward) with ``torch.profiler``.

    A step runs ``fn`` under optional autocast and keeps the first element if ``fn`` returns a
    tuple. In backward mode, it first clears ``.grad`` on tensor inputs and afterwards
    backpropagates a fixed random gradient, drawn once up front. The step is run 30 times to
    warm up, then once under the profiler. CUDA activity is always requested; CPU activity is
    added when ``cpu`` is True.

    Args:
        trace_filename: If not None, a Chrome trace of the profiled step is written there.
        verbose: Print the profiler's key-averages table (first 50 rows).
    """

    def forward_once():
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            result = fn(*inputs, **kwinputs)
            if type(result) is tuple:
                result = result[0]
        return result

    # One random upstream gradient, reused by every backward step.
    out_grad = torch.randn_like(forward_once()) if backward else None

    def step():
        if backward:
            for tensor in inputs:
                if isinstance(tensor, torch.Tensor):
                    tensor.grad = None
        out = forward_once()
        # Backward is done outside autocast.
        if backward:
            out.backward(out_grad, retain_graph=True)

    for _ in range(30):  # warm-up
        step()

    activities = [torch.profiler.ProfilerActivity.CUDA]
    if cpu:
        activities.insert(0, torch.profiler.ProfilerActivity.CPU)
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        with_stack=True,
    ) as prof:
        step()
    if verbose:
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    """Measure peak CUDA memory allocated while running ``fn(*inputs, **kwinputs)`` once.

    The CUDA caching allocator is emptied and its peak statistics are reset before the call.
    The device is synchronized before and after the call, and the cache is emptied again
    afterwards.

    Returns:
        Peak allocated memory in GiB (``2**30`` bytes). The printed message keeps its
        original ``GB`` label.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    # Use one consistent power-of-two unit; the old divisor 2**20 * 1000 mixed MiB with a
    # decimal factor and matched neither GB (1e9 bytes) nor GiB (2**30 bytes).
    mem = torch.cuda.max_memory_allocated() / 2**30
    if verbose:
        print(f"{desc} max memory: {mem}GB")
    torch.cuda.empty_cache()
    return mem
|