triton-windows 3.1.0.post17__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Potentially problematic release.
This version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +73 -0
- triton/backends/__init__.py +50 -0
- triton/backends/amd/compiler.py +262 -0
- triton/backends/amd/driver.c +211 -0
- triton/backends/amd/driver.py +497 -0
- triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
- triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
- triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
- triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
- triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
- triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
- triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
- triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
- triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
- triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
- triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
- triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
- triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
- triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
- triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
- triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
- triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
- triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
- triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
- triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
- triton/backends/amd/include/hip/channel_descriptor.h +39 -0
- triton/backends/amd/include/hip/device_functions.h +38 -0
- triton/backends/amd/include/hip/driver_types.h +468 -0
- triton/backends/amd/include/hip/hip_bf16.h +36 -0
- triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
- triton/backends/amd/include/hip/hip_common.h +100 -0
- triton/backends/amd/include/hip/hip_complex.h +38 -0
- triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
- triton/backends/amd/include/hip/hip_deprecated.h +95 -0
- triton/backends/amd/include/hip/hip_ext.h +159 -0
- triton/backends/amd/include/hip/hip_fp16.h +36 -0
- triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
- triton/backends/amd/include/hip/hip_hcc.h +24 -0
- triton/backends/amd/include/hip/hip_math_constants.h +36 -0
- triton/backends/amd/include/hip/hip_profile.h +27 -0
- triton/backends/amd/include/hip/hip_runtime.h +75 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
- triton/backends/amd/include/hip/hip_texture_types.h +29 -0
- triton/backends/amd/include/hip/hip_vector_types.h +41 -0
- triton/backends/amd/include/hip/hip_version.h +17 -0
- triton/backends/amd/include/hip/hiprtc.h +421 -0
- triton/backends/amd/include/hip/library_types.h +78 -0
- triton/backends/amd/include/hip/math_functions.h +42 -0
- triton/backends/amd/include/hip/surface_types.h +63 -0
- triton/backends/amd/include/hip/texture_types.h +194 -0
- triton/backends/amd/include/hsa/Brig.h +1131 -0
- triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
- triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
- triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
- triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
- triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
- triton/backends/amd/include/hsa/hsa.h +5729 -0
- triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
- triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
- triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
- triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
- triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
- triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
- triton/backends/amd/include/roctracer/roctracer.h +779 -0
- triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
- triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
- triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
- triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
- triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
- triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
- triton/backends/amd/include/roctracer/roctx.h +229 -0
- triton/backends/amd/lib/ockl.bc +0 -0
- triton/backends/amd/lib/ocml.bc +0 -0
- triton/backends/compiler.py +76 -0
- triton/backends/driver.py +34 -0
- triton/backends/nvidia/__init__.py +0 -0
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +347 -0
- triton/backends/nvidia/driver.c +451 -0
- triton/backends/nvidia/driver.py +430 -0
- triton/backends/nvidia/include/cuda.h +24359 -0
- triton/backends/nvidia/lib/libdevice.10.bc +0 -0
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +4 -0
- triton/compiler/code_generator.py +1302 -0
- triton/compiler/compiler.py +416 -0
- triton/compiler/errors.py +51 -0
- triton/compiler/make_launcher.py +0 -0
- triton/errors.py +5 -0
- triton/language/__init__.py +284 -0
- triton/language/core.py +2621 -0
- triton/language/extra/__init__.py +4 -0
- triton/language/extra/cuda/__init__.py +8 -0
- triton/language/extra/cuda/libdevice.py +1629 -0
- triton/language/extra/cuda/utils.py +109 -0
- triton/language/extra/hip/__init__.py +3 -0
- triton/language/extra/hip/libdevice.py +468 -0
- triton/language/extra/libdevice.py +1213 -0
- triton/language/math.py +250 -0
- triton/language/random.py +207 -0
- triton/language/semantic.py +1621 -0
- triton/language/standard.py +441 -0
- triton/ops/__init__.py +7 -0
- triton/ops/blocksparse/__init__.py +7 -0
- triton/ops/blocksparse/matmul.py +432 -0
- triton/ops/blocksparse/softmax.py +228 -0
- triton/ops/cross_entropy.py +96 -0
- triton/ops/flash_attention.py +466 -0
- triton/ops/matmul.py +219 -0
- triton/ops/matmul_perf_model.py +171 -0
- triton/runtime/__init__.py +23 -0
- triton/runtime/autotuner.py +361 -0
- triton/runtime/build.py +129 -0
- triton/runtime/cache.py +289 -0
- triton/runtime/driver.py +60 -0
- triton/runtime/errors.py +26 -0
- triton/runtime/interpreter.py +1127 -0
- triton/runtime/jit.py +956 -0
- triton/runtime/tcc/include/_mingw.h +170 -0
- triton/runtime/tcc/include/assert.h +57 -0
- triton/runtime/tcc/include/conio.h +409 -0
- triton/runtime/tcc/include/ctype.h +281 -0
- triton/runtime/tcc/include/dir.h +31 -0
- triton/runtime/tcc/include/direct.h +68 -0
- triton/runtime/tcc/include/dirent.h +135 -0
- triton/runtime/tcc/include/dos.h +55 -0
- triton/runtime/tcc/include/errno.h +75 -0
- triton/runtime/tcc/include/excpt.h +123 -0
- triton/runtime/tcc/include/fcntl.h +52 -0
- triton/runtime/tcc/include/fenv.h +108 -0
- triton/runtime/tcc/include/float.h +57 -0
- triton/runtime/tcc/include/inttypes.h +297 -0
- triton/runtime/tcc/include/io.h +418 -0
- triton/runtime/tcc/include/limits.h +111 -0
- triton/runtime/tcc/include/locale.h +91 -0
- triton/runtime/tcc/include/malloc.h +181 -0
- triton/runtime/tcc/include/math.h +737 -0
- triton/runtime/tcc/include/mem.h +13 -0
- triton/runtime/tcc/include/memory.h +40 -0
- triton/runtime/tcc/include/process.h +176 -0
- triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
- triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
- triton/runtime/tcc/include/sec_api/io_s.h +33 -0
- triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
- triton/runtime/tcc/include/sec_api/search_s.h +25 -0
- triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
- triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
- triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
- triton/runtime/tcc/include/sec_api/string_s.h +41 -0
- triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
- triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
- triton/runtime/tcc/include/sec_api/time_s.h +61 -0
- triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
- triton/runtime/tcc/include/setjmp.h +160 -0
- triton/runtime/tcc/include/share.h +28 -0
- triton/runtime/tcc/include/signal.h +63 -0
- triton/runtime/tcc/include/stdarg.h +79 -0
- triton/runtime/tcc/include/stdbool.h +11 -0
- triton/runtime/tcc/include/stddef.h +54 -0
- triton/runtime/tcc/include/stdint.h +212 -0
- triton/runtime/tcc/include/stdio.h +429 -0
- triton/runtime/tcc/include/stdlib.h +580 -0
- triton/runtime/tcc/include/string.h +164 -0
- triton/runtime/tcc/include/sys/fcntl.h +13 -0
- triton/runtime/tcc/include/sys/file.h +14 -0
- triton/runtime/tcc/include/sys/locking.h +30 -0
- triton/runtime/tcc/include/sys/stat.h +290 -0
- triton/runtime/tcc/include/sys/time.h +69 -0
- triton/runtime/tcc/include/sys/timeb.h +133 -0
- triton/runtime/tcc/include/sys/types.h +118 -0
- triton/runtime/tcc/include/sys/unistd.h +14 -0
- triton/runtime/tcc/include/sys/utime.h +146 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
- triton/runtime/tcc/include/tcclib.h +80 -0
- triton/runtime/tcc/include/tchar.h +1102 -0
- triton/runtime/tcc/include/time.h +287 -0
- triton/runtime/tcc/include/vadefs.h +11 -0
- triton/runtime/tcc/include/values.h +4 -0
- triton/runtime/tcc/include/varargs.h +12 -0
- triton/runtime/tcc/include/wchar.h +873 -0
- triton/runtime/tcc/include/wctype.h +172 -0
- triton/runtime/tcc/include/winapi/basetsd.h +149 -0
- triton/runtime/tcc/include/winapi/basetyps.h +85 -0
- triton/runtime/tcc/include/winapi/guiddef.h +156 -0
- triton/runtime/tcc/include/winapi/poppack.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
- triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
- triton/runtime/tcc/include/winapi/winbase.h +2951 -0
- triton/runtime/tcc/include/winapi/wincon.h +301 -0
- triton/runtime/tcc/include/winapi/windef.h +293 -0
- triton/runtime/tcc/include/winapi/windows.h +127 -0
- triton/runtime/tcc/include/winapi/winerror.h +3166 -0
- triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
- triton/runtime/tcc/include/winapi/winnt.h +5835 -0
- triton/runtime/tcc/include/winapi/winreg.h +272 -0
- triton/runtime/tcc/include/winapi/winuser.h +5651 -0
- triton/runtime/tcc/include/winapi/winver.h +160 -0
- triton/runtime/tcc/lib/cuda.def +697 -0
- triton/runtime/tcc/lib/gdi32.def +337 -0
- triton/runtime/tcc/lib/kernel32.def +770 -0
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- triton/runtime/tcc/lib/msvcrt.def +1399 -0
- triton/runtime/tcc/lib/python3.def +810 -0
- triton/runtime/tcc/lib/user32.def +658 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/testing.py +496 -0
- triton/tools/__init__.py +0 -0
- triton/tools/build_extern.py +365 -0
- triton/tools/compile.c +67 -0
- triton/tools/compile.h +14 -0
- triton/tools/compile.py +145 -0
- triton/tools/disasm.py +142 -0
- triton/tools/link.py +322 -0
- triton/windows_utils.py +373 -0
- triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
- triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
- triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
- triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
triton/ops/cross_entropy.py
@@ -0,0 +1,96 @@
+import torch
+
+from .. import heuristics, jit
+from .. import language as tl
+from .. import next_power_of_2
+
+
+def num_warps(N):
+    if N < 2048:
+        return 4
+    elif N < 8192:
+        return 8
+    return 16
+
+
+@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
+@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
+@jit
+def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
+    row = tl.program_id(0)
+    cols = tl.arange(0, BLOCK)
+    idx = tl.load(IDX + row)
+    # pointers to logit and probs
+    LOGITS = LOGITS + row * N + cols
+    WRIT_PROBS = PROBS + row * N + cols
+    READ_PROBS = PROBS + row * N + idx
+    # write-back negative log-probs
+    logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
+    logits = logits.to(tl.float32)
+    logits = logits - tl.max(logits, 0)
+    probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
+    tl.store(WRIT_PROBS, probs, mask=cols < N)
+    # There is a bug in the compiler, which fails to insert a barrier here.
+    # We add it explicitly for now. Will be fixed soon.
+    tl.debug_barrier()
+    # write-back loss
+    probs = tl.load(READ_PROBS)
+    tl.store(LOSS + row, probs)
+
+
+@heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
+@heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
+@jit
+def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
+    row = tl.program_id(0)
+    cols = tl.arange(0, BLOCK)
+    idx = tl.load(IDX + row)
+    # pointers to probs
+    PROBS = PROBS + row * N + cols
+    # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
+    # and we have -log(p[k]) stored in PROBS, so this is easy
+    probs = -tl.load(PROBS, mask=cols < N, other=float('inf'))
+    probs = tl.exp(probs.to(tl.float32))
+    delta = cols == idx
+    # write result in-place in PROBS
+    dout = tl.load(DPROBS + row)
+    din = (probs - delta) * dout
+    tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N)
+
+
+class _cross_entropy(torch.autograd.Function):
+
+    @classmethod
+    def forward(cls, ctx, logits, indices):
+        # make sure we can use triton
+        assert (indices.dtype == torch.int64), "Indices are expected to be of type long."
+        # make kernel
+        device, dtype = logits.device, logits.dtype
+        n_cols = logits.shape[-1]
+        # run the kernel
+        result = torch.empty_like(indices, dtype=dtype, device=device)
+        neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
+        grid = lambda opt: (logits.numel() // n_cols, )
+        _forward[grid](logits, neg_logprobs, indices, result, n_cols)
+        # save for backward
+        ctx.save_for_backward(neg_logprobs, indices)
+        return result
+
+    @classmethod
+    def backward(cls, ctx, dneg_logprobs):
+        """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
+        so we initialize the gradient as neg_logprobs, so we can just exponentiate
+        to get p[k], which is most of what we need... neg_logprobs will be
+        modified in place to become the gradient we want
+        """
+        # load saved tensors
+        neg_logprobs, indices = ctx.saved_tensors
+        # run the kernel
+        # neg_logprobs will be modified in place to become our gradient:
+        n_cols = neg_logprobs.shape[-1]
+        grid = lambda opt: (neg_logprobs.numel() // n_cols, )
+        _backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols)
+        return neg_logprobs, None
+
+
+cross_entropy = _cross_entropy.apply
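The backward kernel above relies on the standard softmax gradient identity d(-log p[i])/dlogit[k] = p[k] - 1[i==k], which is why storing negative log-probs in the forward pass is all the state it needs. For reference, a minimal usage sketch of the op (assuming a CUDA device and that triton/ops/__init__.py re-exports cross_entropy, as upstream Triton does; the shapes and sizes are illustrative, not taken from the diff):

    import torch
    from triton.ops import cross_entropy

    # logits: one row per sample; indices: int64 class ids, as asserted in forward()
    logits = torch.randn(8, 1000, device='cuda', dtype=torch.float16, requires_grad=True)
    indices = torch.randint(0, 1000, (8,), device='cuda', dtype=torch.int64)

    loss = cross_entropy(logits, indices)  # per-row negative log-probability
    loss.sum().backward()                  # gradient computed in place by _backward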
triton/ops/flash_attention.py
@@ -0,0 +1,466 @@
+"""
+Fused Attention
+===============
+This is a Triton implementation of the Flash Attention algorithm
+(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
+
+Sequence Parallel implementation inspired by HazyResearch
+(see https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py)
+"""
+
+import torch
+import triton
+
+from .. import cdiv, jit
+from .. import language as tl
+
+
+def is_hip():
+    return triton.runtime.driver.active.get_current_target().backend == "hip"
+
+
+@jit
+def _fwd_kernel(Q, K, V, sm_scale,  #
+                L,  #
+                Out,  #
+                stride_qz, stride_qh, stride_qm, stride_qk,  #
+                stride_kz, stride_kh, stride_kn, stride_kk,  #
+                stride_vz, stride_vh, stride_vn, stride_vk,  #
+                stride_oz, stride_oh, stride_om, stride_on,  #
+                Z, H, N_CTX,  #
+                Z_H_N_CTX,  #
+                BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,  #
+                BLOCK_N: tl.constexpr,  #
+                IS_CAUSAL: tl.constexpr  #
+                ):
+    start_m = tl.program_id(0)
+    off_hz = tl.program_id(1)
+    qvk_offset = off_hz * stride_qh
+    vk_offset = qvk_offset // stride_qm
+
+    K_block_ptr = tl.make_block_ptr(
+        base=K,
+        shape=(BLOCK_DMODEL, Z_H_N_CTX),
+        strides=(stride_kk, stride_kn),
+        offsets=(0, vk_offset),
+        block_shape=(BLOCK_DMODEL, BLOCK_N),
+        order=(0, 1),
+    )
+    V_block_ptr = tl.make_block_ptr(
+        base=V,
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
+        strides=(stride_vn, stride_vk),
+        offsets=(vk_offset, 0),
+        block_shape=(BLOCK_N, BLOCK_DMODEL),
+        order=(1, 0),
+    )
+    # initialize offsets
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    # initialize pointer to m and l
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    # credits to: Adam P. Goucher (https://github.com/apgoucher):
+    # scale sm_scale by 1/log_2(e) and use
+    # 2^x instead of exp in the loop because CSE and LICM
+    # don't work as expected with `exp` in the loop
+    qk_scale = sm_scale * 1.44269504
+    # load q: it will stay in SRAM throughout
+
+    offs_k = tl.arange(0, BLOCK_DMODEL)
+    Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
+    q = tl.load(Q_ptrs)
+
+    q = (q * qk_scale).to(K.dtype.element_ty)
+    lo = 0
+    hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
+    for start_n in range(lo, hi, BLOCK_N):
+        # -- load k, v --
+        k = tl.load(K_block_ptr)
+        v = tl.load(V_block_ptr)
+        # -- compute qk ---
+        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        if IS_CAUSAL:
+            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+        qk += tl.dot(q, k)
+        # -- compute scaling constant ---
+        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
+        alpha = tl.math.exp2(m_i - m_i_new)
+        p = tl.math.exp2(qk - m_i_new[:, None])
+        # -- scale and update acc --
+        acc *= alpha[:, None]
+        acc += tl.dot(p.to(V.dtype.element_ty), v)
+        # -- update m_i and l_i --
+        l_i = l_i * alpha + tl.sum(p, 1)
+        m_i = m_i_new
+        # update pointers
+        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
+        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
+    # write back l and m
+    acc = acc / l_i[:, None]
+    l_ptrs = L + off_hz * N_CTX + offs_m
+    tl.store(l_ptrs, m_i + tl.math.log2(l_i))
+    # write back O
+    O_block_ptr = tl.make_block_ptr(
+        base=Out,
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
+        strides=(stride_om, stride_on),
+        offsets=(vk_offset + start_m * BLOCK_M, 0),
+        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        order=(1, 0),
+    )
+    # O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
+    tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
+
+
+@jit
+def _bwd_preprocess(
+    Out,
+    DO,
+    Delta,
+    BLOCK_M: tl.constexpr,
+    D_HEAD: tl.constexpr,
+):
+    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
+    off_n = tl.arange(0, D_HEAD)
+    # load
+    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
+    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
+    # compute
+    delta = tl.sum(o * do, axis=1)
+    # write-back
+    tl.store(Delta + off_m, delta)
+
+
+@jit
+def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale,  #
+                              Out, DO,  #
+                              DQ, DK, DV,  #
+                              L,  #
+                              D,  #
+                              Q_block_ptr, K_block_ptr, V_block_ptr,  #
+                              DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,  #
+                              stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,  #
+                              stride_kz, stride_kh, stride_kn, stride_kk,  #
+                              stride_vz, stride_vh, stride_vn, stride_vk,  #
+                              Z, H, N_CTX,  #
+                              off_h, off_z, off_hz, start_n, num_block,  #
+                              BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,  #
+                              BLOCK_N: tl.constexpr,  #
+                              SEQUENCE_PARALLEL: tl.constexpr,  #
+                              CAUSAL: tl.constexpr,  #
+                              MMA_V3: tl.constexpr  #
+                              ):
+    if CAUSAL:
+        lo = start_n * BLOCK_M
+    else:
+        lo = 0
+
+    Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm
+    DQ_offset = off_z * stride_qz + off_h * stride_qh
+    K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn
+    V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn
+    if SEQUENCE_PARALLEL:
+        DQ_offset += stride_dqa * start_n
+    DQ_offset = DQ_offset // stride_qm
+
+    Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0))
+    K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0))
+    V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0))
+    DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0))
+    DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0))
+    DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0))
+    DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0))
+
+    # initialize row/col offsets
+    offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_m = tl.arange(0, BLOCK_N)
+    # pointer to row-wise quantities in value-like data
+    D_ptrs = D + off_hz * N_CTX
+    l_ptrs = L + off_hz * N_CTX
+    # initialize dv amd dk
+    dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    # k and v stay in SRAM throughout
+    k = tl.load(K_block_ptr)
+    v = tl.load(V_block_ptr)
+    # loop over rows
+    for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
+        offs_m_curr = start_m + offs_m
+        # load q, k, v, do on-chip
+        q = tl.load(Q_block_ptr)
+        # recompute p = softmax(qk, dim=-1).T
+        # NOTE: `do` is pre-divided by `l`; no normalization here
+        if CAUSAL:
+            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf"))
+        else:
+            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        qk += tl.dot(q, tl.trans(k))
+        qk *= qk_scale
+        l_i = tl.load(l_ptrs + offs_m_curr)
+        p = tl.math.exp2(qk - l_i[:, None])
+        # compute dv
+        do = tl.load(DO_block_ptr)
+        dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
+        # compute dp = dot(v, do)
+        Di = tl.load(D_ptrs + offs_m_curr)
+        # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
+        dp = tl.dot(do, tl.trans(v))
+        # compute ds = p * (dp - delta[:, None])
+        ds = (p * (dp - Di[:, None]) * sm_scale).to(Q.dtype.element_ty)
+        # compute dk = dot(ds.T, q)
+        dk += tl.dot(tl.trans(ds), q)
+        # compute dq
+        if not SEQUENCE_PARALLEL:
+            dq = tl.load(DQ_block_ptr)
+            dq += tl.dot(ds, k)
+            tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
+        elif SEQUENCE_PARALLEL:
+            if MMA_V3:
+                dq = tl.dot(ds, k)
+            else:
+                # not work with mma v3, because M % 64 != 0
+                dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds)))
+            tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
+
+        # increment pointers
+        DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0))
+        Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
+        DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0))
+    # write-back
+    tl.store(DV_block_ptr, dv.to(V.dtype.element_ty))
+    tl.store(DK_block_ptr, dk.to(K.dtype.element_ty))
+
+
+@jit
+def _bwd_kernel(Q, K, V, sm_scale,  #
+                Out, DO,  #
+                DQ, DK, DV,  #
+                L,  #
+                D,  #
+                stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,  #
+                stride_kz, stride_kh, stride_kn, stride_kk,  #
+                stride_vz, stride_vh, stride_vn, stride_vk,  #
+                Z, H, N_CTX,  #
+                Z_H_N_CTX,  #
+                SQ_Z_H_N_CTX,  #
+                BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,  #
+                BLOCK_N: tl.constexpr,  #
+                SEQUENCE_PARALLEL: tl.constexpr,  #
+                CAUSAL: tl.constexpr,  #
+                MMA_V3: tl.constexpr  #
+                ):
+    qk_scale = sm_scale * 1.44269504
+    off_hz = tl.program_id(0)
+    off_z = off_hz // H
+    off_h = off_hz % H
+
+    Q_block_ptr = tl.make_block_ptr(
+        base=Q,
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
+        strides=(stride_qm, stride_qk),
+        offsets=(0, 0),
+        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        order=(1, 0),
+    )
+    K_block_ptr = tl.make_block_ptr(
+        base=K,
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
+        strides=(stride_kn, stride_kk),
+        offsets=(0, 0),
+        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        order=(1, 0),
+    )
+    V_block_ptr = tl.make_block_ptr(
+        base=V,
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
+        strides=(stride_vn, stride_vk),
+        offsets=(0, 0),
+        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        order=(1, 0),
+    )
+    DO_block_ptr = tl.make_block_ptr(
+        base=DO,
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
+        strides=(stride_qm, stride_qk),
+        offsets=(0, 0),
+        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        order=(1, 0),
+    )
+    if SEQUENCE_PARALLEL:
+        DQ_block_ptr = tl.make_block_ptr(
+            base=DQ,
+            shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL),
+            strides=(stride_qm, stride_qk),
+            offsets=(0, 0),
+            block_shape=(BLOCK_M, BLOCK_DMODEL),
+            order=(1, 0),
+        )
+    else:
+        DQ_block_ptr = tl.make_block_ptr(
+            base=DQ,
+            shape=(Z_H_N_CTX, BLOCK_DMODEL),
+            strides=(stride_qm, stride_qk),
+            offsets=(0, 0),
+            block_shape=(BLOCK_M, BLOCK_DMODEL),
+            order=(1, 0),
+        )
+
+    DK_block_ptr = tl.make_block_ptr(
+        base=DK,
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
+        strides=(stride_kn, stride_kk),
+        offsets=(0, 0),
+        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        order=(1, 0),
+    )
+    DV_block_ptr = tl.make_block_ptr(
+        base=DV,
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
+        strides=(stride_vn, stride_vk),
+        offsets=(0, 0),
+        block_shape=(BLOCK_M, BLOCK_DMODEL),
+        order=(1, 0),
+    )
+
+    num_block_n = tl.cdiv(N_CTX, BLOCK_N)
+    if not SEQUENCE_PARALLEL:
+        for start_n in range(0, num_block_n):
+            _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO,  #
+                                      DQ, DK, DV,  #
+                                      L,  #
+                                      D,  #
+                                      Q_block_ptr, K_block_ptr, V_block_ptr,  #
+                                      DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,  #
+                                      stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,  #
+                                      stride_kz, stride_kh, stride_kn, stride_kk,  #
+                                      stride_vz, stride_vh, stride_vn, stride_vk,  #
+                                      Z, H, N_CTX,  #
+                                      off_h, off_z, off_hz, start_n, num_block_n,  #
+                                      BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,  #
+                                      BLOCK_N=BLOCK_N,  #
+                                      SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,  #
+                                      CAUSAL=CAUSAL,  #
+                                      MMA_V3=MMA_V3  #
+                                      )
+    else:
+        start_n = tl.program_id(1)
+        _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO,  #
+                                  DQ, DK, DV,  #
+                                  L,  #
+                                  D,  #
+                                  Q_block_ptr, K_block_ptr, V_block_ptr,  #
+                                  DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr,  #
+                                  stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk,  #
+                                  stride_kz, stride_kh, stride_kn, stride_kk,  #
+                                  stride_vz, stride_vh, stride_vn, stride_vk,  #
+                                  Z, H, N_CTX,  #
+                                  off_h, off_z, off_hz, start_n, num_block_n,  #
+                                  BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,  #
+                                  BLOCK_N=BLOCK_N,  #
+                                  SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,  #
+                                  CAUSAL=CAUSAL,  #
+                                  MMA_V3=MMA_V3  #
+                                  )
+
+
+class _attention(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False):
+        # only support for Ampere now
+        capability = torch.cuda.get_device_capability()
+        if capability[0] < 8:
+            raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
+        BLOCK_M = 128
+        BLOCK_N = 64
+        # shape constraints
+        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+        assert Lq == Lk and Lk == Lv
+        assert Lk in {16, 32, 64, 128}
+        o = torch.empty_like(q)
+        grid = (cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
+        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        num_warps = 4 if Lk <= 64 else 8
+        _fwd_kernel[grid](
+            q, k, v, sm_scale,  #
+            L,  #
+            o,  #
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
+            q.shape[0], q.shape[1], q.shape[2],  #
+            q.shape[0] * q.shape[1] * q.shape[2],  #
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,  #
+            IS_CAUSAL=causal,  #
+            num_warps=num_warps,  #
+            num_stages=4  #
+        )
+
+        ctx.save_for_backward(q, k, v, o, L)
+        ctx.grid = grid
+        ctx.sm_scale = sm_scale
+        ctx.BLOCK_DMODEL = Lk
+        ctx.causal = causal
+        ctx.sequence_parallel = sequence_parallel
+        return o
+
+    @staticmethod
+    def backward(ctx, do):
+        capability = torch.cuda.get_device_capability()
+        MMA_V3 = capability[0] >= 9
+        BLOCK = 128
+
+        if is_hip():
+            # Bwd pass runs out of shared memory on HIP with larger block size.
+            BLOCK = 64
+
+        q, k, v, o, L = ctx.saved_tensors
+        sequence_parallel = ctx.sequence_parallel
+        seq_len_kv = k.shape[2]
+        do = do.contiguous()
+        if sequence_parallel:
+            replicas = cdiv(seq_len_kv, BLOCK)
+            new_dq_shape = (replicas, ) + q.shape
+            dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype)
+        else:
+            dq = torch.zeros_like(q, dtype=q.dtype)
+        dk = torch.empty_like(k)
+        dv = torch.empty_like(v)
+        delta = torch.empty_like(L)
+        _bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](
+            o,
+            do,
+            delta,
+            BLOCK_M=BLOCK,
+            D_HEAD=ctx.BLOCK_DMODEL,
+        )
+        _bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
+            q, k, v, ctx.sm_scale,  #
+            o, do,  #
+            dq, dk, dv,  #
+            L,  #
+            delta,  #
+            o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
+            q.shape[0], q.shape[1], q.shape[2],  #
+            q.shape[0] * q.shape[1] * q.shape[2],  #
+            cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2],  #
+            BLOCK_M=BLOCK, BLOCK_N=BLOCK,  #
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL,  #
+            SEQUENCE_PARALLEL=sequence_parallel,  #
+            CAUSAL=ctx.causal,  #
+            MMA_V3=MMA_V3,  #
+            num_warps=8,  #
+            num_stages=1  #
+        )
+
+        if len(dq.shape) == 5:
+            dq = dq.sum(dim=0)
+        return dq, dk, dv, None, None, None
+
+
+attention = _attention.apply
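A minimal usage sketch of the fused attention op (assuming an NVIDIA GPU with compute capability >= 8.0, per the check in forward(), and that triton/ops/__init__.py re-exports attention, as upstream Triton does; the shapes are illustrative only):

    import math
    import torch
    from triton.ops import attention

    # (Z, H, N_CTX, D_HEAD): head dim must be one of {16, 32, 64, 128}
    q = torch.randn(2, 4, 1024, 64, device='cuda', dtype=torch.float16, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)

    sm_scale = 1.0 / math.sqrt(q.shape[-1])  # conventional 1/sqrt(d) softmax scaling
    o = attention(q, k, v, True, sm_scale)   # causal=True; sequence_parallel defaults to False
    o.sum().backward()                       # runs _bwd_preprocess followed by _bwd_kernel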