triton-windows 3.1.0.post17__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of triton-windows has been flagged as potentially problematic.

Files changed (248)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
triton/ops/cross_entropy.py
@@ -0,0 +1,96 @@
+ import torch
+
+ from .. import heuristics, jit
+ from .. import language as tl
+ from .. import next_power_of_2
+
+
+ def num_warps(N):
+     if N < 2048:
+         return 4
+     elif N < 8192:
+         return 8
+     return 16
+
+
+ @heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
+ @heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
+ @jit
+ def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
+     row = tl.program_id(0)
+     cols = tl.arange(0, BLOCK)
+     idx = tl.load(IDX + row)
+     # pointers to logit and probs
+     LOGITS = LOGITS + row * N + cols
+     WRIT_PROBS = PROBS + row * N + cols
+     READ_PROBS = PROBS + row * N + idx
+     # write-back negative log-probs
+     logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
+     logits = logits.to(tl.float32)
+     logits = logits - tl.max(logits, 0)
+     probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
+     tl.store(WRIT_PROBS, probs, mask=cols < N)
+     # There is a bug in the compiler, which fails to insert a barrier here.
+     # We add it explicitly for now. Will be fixed soon.
+     tl.debug_barrier()
+     # write-back loss
+     probs = tl.load(READ_PROBS)
+     tl.store(LOSS + row, probs)
+
+
+ @heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
+ @heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
+ @jit
+ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
+     row = tl.program_id(0)
+     cols = tl.arange(0, BLOCK)
+     idx = tl.load(IDX + row)
+     # pointers to probs
+     PROBS = PROBS + row * N + cols
+     # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
+     # and we have -log(p[k]) stored in PROBS, so this is easy
+     probs = -tl.load(PROBS, mask=cols < N, other=float('inf'))
+     probs = tl.exp(probs.to(tl.float32))
+     delta = cols == idx
+     # write result in-place in PROBS
+     dout = tl.load(DPROBS + row)
+     din = (probs - delta) * dout
+     tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N)
+
+
+ class _cross_entropy(torch.autograd.Function):
+
+     @classmethod
+     def forward(cls, ctx, logits, indices):
+         # make sure we can use triton
+         assert (indices.dtype == torch.int64), "Indices are expected to be of type long."
+         # make kernel
+         device, dtype = logits.device, logits.dtype
+         n_cols = logits.shape[-1]
+         # run the kernel
+         result = torch.empty_like(indices, dtype=dtype, device=device)
+         neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
+         grid = lambda opt: (logits.numel() // n_cols, )
+         _forward[grid](logits, neg_logprobs, indices, result, n_cols)
+         # save for backward
+         ctx.save_for_backward(neg_logprobs, indices)
+         return result
+
+     @classmethod
+     def backward(cls, ctx, dneg_logprobs):
+         """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
+         so we initialize the gradient as neg_logprobs, so we can just exponentiate
+         to get p[k], which is most of what we need... neg_logprobs will be
+         modified in place to become the gradient we want
+         """
+         # load saved tensors
+         neg_logprobs, indices = ctx.saved_tensors
+         # run the kernel
+         # neg_logprobs will be modified in place to become our gradient:
+         n_cols = neg_logprobs.shape[-1]
+         grid = lambda opt: (neg_logprobs.numel() // n_cols, )
+         _backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols)
+         return neg_logprobs, None
+
+
+ cross_entropy = _cross_entropy.apply
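
The `cross_entropy = _cross_entropy.apply` entry point above takes a logits tensor whose last dimension is the class dimension and a vector of int64 class indices, and returns the per-row negative log-likelihood. A minimal usage sketch follows; it assumes a CUDA device and that `triton.ops` re-exports `cross_entropy` (see triton/ops/__init__.py in the file list), and the shapes are illustrative only.

import torch
from triton.ops import cross_entropy  # assumed re-export, see triton/ops/__init__.py above

# logits: (rows, classes); indices must be int64, one target class per row
logits = torch.randn(128, 1000, device="cuda", dtype=torch.float16, requires_grad=True)
targets = torch.randint(0, 1000, (128,), device="cuda", dtype=torch.int64)
loss = cross_entropy(logits, targets)  # per-row losses, shape (128,)
loss.sum().backward()                  # gradient w.r.t. logits lands in logits.grad
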
triton/ops/flash_attention.py
@@ -0,0 +1,466 @@
+ """
+ Fused Attention
+ ===============
+ This is a Triton implementation of the Flash Attention algorithm
+ (see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
+
+ Sequence Parallel implementation inspired by HazyResearch
+ (see https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py)
+ """
+
+ import torch
+ import triton
+
+ from .. import cdiv, jit
+ from .. import language as tl
+
+
+ def is_hip():
+     return triton.runtime.driver.active.get_current_target().backend == "hip"
+
+
+ @jit
+ def _fwd_kernel(Q, K, V, sm_scale, #
+                 L, #
+                 Out, #
+                 stride_qz, stride_qh, stride_qm, stride_qk, #
+                 stride_kz, stride_kh, stride_kn, stride_kk, #
+                 stride_vz, stride_vh, stride_vn, stride_vk, #
+                 stride_oz, stride_oh, stride_om, stride_on, #
+                 Z, H, N_CTX, #
+                 Z_H_N_CTX, #
+                 BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
+                 BLOCK_N: tl.constexpr, #
+                 IS_CAUSAL: tl.constexpr #
+                 ):
+     start_m = tl.program_id(0)
+     off_hz = tl.program_id(1)
+     qvk_offset = off_hz * stride_qh
+     vk_offset = qvk_offset // stride_qm
+
+     K_block_ptr = tl.make_block_ptr(
+         base=K,
+         shape=(BLOCK_DMODEL, Z_H_N_CTX),
+         strides=(stride_kk, stride_kn),
+         offsets=(0, vk_offset),
+         block_shape=(BLOCK_DMODEL, BLOCK_N),
+         order=(0, 1),
+     )
+     V_block_ptr = tl.make_block_ptr(
+         base=V,
+         shape=(Z_H_N_CTX, BLOCK_DMODEL),
+         strides=(stride_vn, stride_vk),
+         offsets=(vk_offset, 0),
+         block_shape=(BLOCK_N, BLOCK_DMODEL),
+         order=(1, 0),
+     )
+     # initialize offsets
+     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_n = tl.arange(0, BLOCK_N)
+     # initialize pointer to m and l
+     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+     # credits to: Adam P. Goucher (https://github.com/apgoucher):
+     # scale sm_scale by 1/log_2(e) and use
+     # 2^x instead of exp in the loop because CSE and LICM
+     # don't work as expected with `exp` in the loop
+     qk_scale = sm_scale * 1.44269504
+     # load q: it will stay in SRAM throughout
+
+     offs_k = tl.arange(0, BLOCK_DMODEL)
+     Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
+     q = tl.load(Q_ptrs)
+
+     q = (q * qk_scale).to(K.dtype.element_ty)
+     lo = 0
+     hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
+     for start_n in range(lo, hi, BLOCK_N):
+         # -- load k, v --
+         k = tl.load(K_block_ptr)
+         v = tl.load(V_block_ptr)
+         # -- compute qk ---
+         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+         if IS_CAUSAL:
+             qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+         qk += tl.dot(q, k)
+         # -- compute scaling constant ---
+         m_i_new = tl.maximum(m_i, tl.max(qk, 1))
+         alpha = tl.math.exp2(m_i - m_i_new)
+         p = tl.math.exp2(qk - m_i_new[:, None])
+         # -- scale and update acc --
+         acc *= alpha[:, None]
+         acc += tl.dot(p.to(V.dtype.element_ty), v)
+         # -- update m_i and l_i --
+         l_i = l_i * alpha + tl.sum(p, 1)
+         m_i = m_i_new
+         # update pointers
+         K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
+         V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
+     # write back l and m
+     acc = acc / l_i[:, None]
+     l_ptrs = L + off_hz * N_CTX + offs_m
+     tl.store(l_ptrs, m_i + tl.math.log2(l_i))
+     # write back O
+     O_block_ptr = tl.make_block_ptr(
+         base=Out,
+         shape=(Z_H_N_CTX, BLOCK_DMODEL),
+         strides=(stride_om, stride_on),
+         offsets=(vk_offset + start_m * BLOCK_M, 0),
+         block_shape=(BLOCK_M, BLOCK_DMODEL),
+         order=(1, 0),
+     )
+     # O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
+     tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
+
+
+ @jit
+ def _bwd_preprocess(
+     Out,
+     DO,
+     Delta,
+     BLOCK_M: tl.constexpr,
+     D_HEAD: tl.constexpr,
+ ):
+     off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
+     off_n = tl.arange(0, D_HEAD)
+     # load
+     o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
+     do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
+     # compute
+     delta = tl.sum(o * do, axis=1)
+     # write-back
+     tl.store(Delta + off_m, delta)
+
+
+ @jit
+ def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, #
+                               Out, DO, #
+                               DQ, DK, DV, #
+                               L, #
+                               D, #
+                               Q_block_ptr, K_block_ptr, V_block_ptr, #
+                               DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
+                               stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
+                               stride_kz, stride_kh, stride_kn, stride_kk, #
+                               stride_vz, stride_vh, stride_vn, stride_vk, #
+                               Z, H, N_CTX, #
+                               off_h, off_z, off_hz, start_n, num_block, #
+                               BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
+                               BLOCK_N: tl.constexpr, #
+                               SEQUENCE_PARALLEL: tl.constexpr, #
+                               CAUSAL: tl.constexpr, #
+                               MMA_V3: tl.constexpr #
+                               ):
+     if CAUSAL:
+         lo = start_n * BLOCK_M
+     else:
+         lo = 0
+
+     Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm
+     DQ_offset = off_z * stride_qz + off_h * stride_qh
+     K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn
+     V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn
+     if SEQUENCE_PARALLEL:
+         DQ_offset += stride_dqa * start_n
+     DQ_offset = DQ_offset // stride_qm
+
+     Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0))
+     K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0))
+     V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0))
+     DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0))
+     DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0))
+     DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0))
+     DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0))
+
+     # initialize row/col offsets
+     offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_m = tl.arange(0, BLOCK_N)
+     # pointer to row-wise quantities in value-like data
+     D_ptrs = D + off_hz * N_CTX
+     l_ptrs = L + off_hz * N_CTX
+     # initialize dv amd dk
+     dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+     dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+     # k and v stay in SRAM throughout
+     k = tl.load(K_block_ptr)
+     v = tl.load(V_block_ptr)
+     # loop over rows
+     for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
+         offs_m_curr = start_m + offs_m
+         # load q, k, v, do on-chip
+         q = tl.load(Q_block_ptr)
+         # recompute p = softmax(qk, dim=-1).T
+         # NOTE: `do` is pre-divided by `l`; no normalization here
+         if CAUSAL:
+             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.0), float("-inf"))
+         else:
+             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+         qk += tl.dot(q, tl.trans(k))
+         qk *= qk_scale
+         l_i = tl.load(l_ptrs + offs_m_curr)
+         p = tl.math.exp2(qk - l_i[:, None])
+         # compute dv
+         do = tl.load(DO_block_ptr)
+         dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
+         # compute dp = dot(v, do)
+         Di = tl.load(D_ptrs + offs_m_curr)
+         # dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
+         dp = tl.dot(do, tl.trans(v))
+         # compute ds = p * (dp - delta[:, None])
+         ds = (p * (dp - Di[:, None]) * sm_scale).to(Q.dtype.element_ty)
+         # compute dk = dot(ds.T, q)
+         dk += tl.dot(tl.trans(ds), q)
+         # compute dq
+         if not SEQUENCE_PARALLEL:
+             dq = tl.load(DQ_block_ptr)
+             dq += tl.dot(ds, k)
+             tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
+         elif SEQUENCE_PARALLEL:
+             if MMA_V3:
+                 dq = tl.dot(ds, k)
+             else:
+                 # not work with mma v3, because M % 64 != 0
+                 dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds)))
+             tl.store(DQ_block_ptr, dq.to(Q.dtype.element_ty))
+
+         # increment pointers
+         DQ_block_ptr = tl.advance(DQ_block_ptr, (BLOCK_M, 0))
+         Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_M, 0))
+         DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_M, 0))
+     # write-back
+     tl.store(DV_block_ptr, dv.to(V.dtype.element_ty))
+     tl.store(DK_block_ptr, dk.to(K.dtype.element_ty))
+
+
+ @jit
+ def _bwd_kernel(Q, K, V, sm_scale, #
+                 Out, DO, #
+                 DQ, DK, DV, #
+                 L, #
+                 D, #
+                 stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
+                 stride_kz, stride_kh, stride_kn, stride_kk, #
+                 stride_vz, stride_vh, stride_vn, stride_vk, #
+                 Z, H, N_CTX, #
+                 Z_H_N_CTX, #
+                 SQ_Z_H_N_CTX, #
+                 BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, #
+                 BLOCK_N: tl.constexpr, #
+                 SEQUENCE_PARALLEL: tl.constexpr, #
+                 CAUSAL: tl.constexpr, #
+                 MMA_V3: tl.constexpr #
+                 ):
+     qk_scale = sm_scale * 1.44269504
+     off_hz = tl.program_id(0)
+     off_z = off_hz // H
+     off_h = off_hz % H
+
+     Q_block_ptr = tl.make_block_ptr(
+         base=Q,
+         shape=(Z_H_N_CTX, BLOCK_DMODEL),
+         strides=(stride_qm, stride_qk),
+         offsets=(0, 0),
+         block_shape=(BLOCK_M, BLOCK_DMODEL),
+         order=(1, 0),
+     )
+     K_block_ptr = tl.make_block_ptr(
+         base=K,
+         shape=(Z_H_N_CTX, BLOCK_DMODEL),
+         strides=(stride_kn, stride_kk),
+         offsets=(0, 0),
+         block_shape=(BLOCK_M, BLOCK_DMODEL),
+         order=(1, 0),
+     )
+     V_block_ptr = tl.make_block_ptr(
+         base=V,
+         shape=(Z_H_N_CTX, BLOCK_DMODEL),
+         strides=(stride_vn, stride_vk),
+         offsets=(0, 0),
+         block_shape=(BLOCK_M, BLOCK_DMODEL),
+         order=(1, 0),
+     )
+     DO_block_ptr = tl.make_block_ptr(
+         base=DO,
+         shape=(Z_H_N_CTX, BLOCK_DMODEL),
+         strides=(stride_qm, stride_qk),
+         offsets=(0, 0),
+         block_shape=(BLOCK_M, BLOCK_DMODEL),
+         order=(1, 0),
+     )
+     if SEQUENCE_PARALLEL:
+         DQ_block_ptr = tl.make_block_ptr(
+             base=DQ,
+             shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL),
+             strides=(stride_qm, stride_qk),
+             offsets=(0, 0),
+             block_shape=(BLOCK_M, BLOCK_DMODEL),
+             order=(1, 0),
+         )
+     else:
+         DQ_block_ptr = tl.make_block_ptr(
+             base=DQ,
+             shape=(Z_H_N_CTX, BLOCK_DMODEL),
+             strides=(stride_qm, stride_qk),
+             offsets=(0, 0),
+             block_shape=(BLOCK_M, BLOCK_DMODEL),
+             order=(1, 0),
+         )
+
+     DK_block_ptr = tl.make_block_ptr(
+         base=DK,
+         shape=(Z_H_N_CTX, BLOCK_DMODEL),
+         strides=(stride_kn, stride_kk),
+         offsets=(0, 0),
+         block_shape=(BLOCK_M, BLOCK_DMODEL),
+         order=(1, 0),
+     )
+     DV_block_ptr = tl.make_block_ptr(
+         base=DV,
+         shape=(Z_H_N_CTX, BLOCK_DMODEL),
+         strides=(stride_vn, stride_vk),
+         offsets=(0, 0),
+         block_shape=(BLOCK_M, BLOCK_DMODEL),
+         order=(1, 0),
+     )
+
+     num_block_n = tl.cdiv(N_CTX, BLOCK_N)
+     if not SEQUENCE_PARALLEL:
+         for start_n in range(0, num_block_n):
+             _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
+                                       DQ, DK, DV, #
+                                       L, #
+                                       D, #
+                                       Q_block_ptr, K_block_ptr, V_block_ptr, #
+                                       DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
+                                       stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
+                                       stride_kz, stride_kh, stride_kn, stride_kk, #
+                                       stride_vz, stride_vh, stride_vn, stride_vk, #
+                                       Z, H, N_CTX, #
+                                       off_h, off_z, off_hz, start_n, num_block_n, #
+                                       BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
+                                       BLOCK_N=BLOCK_N, #
+                                       SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
+                                       CAUSAL=CAUSAL, #
+                                       MMA_V3=MMA_V3 #
+                                       )
+     else:
+         start_n = tl.program_id(1)
+         _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, #
+                                   DQ, DK, DV, #
+                                   L, #
+                                   D, #
+                                   Q_block_ptr, K_block_ptr, V_block_ptr, #
+                                   DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, #
+                                   stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, #
+                                   stride_kz, stride_kh, stride_kn, stride_kk, #
+                                   stride_vz, stride_vh, stride_vn, stride_vk, #
+                                   Z, H, N_CTX, #
+                                   off_h, off_z, off_hz, start_n, num_block_n, #
+                                   BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, #
+                                   BLOCK_N=BLOCK_N, #
+                                   SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, #
+                                   CAUSAL=CAUSAL, #
+                                   MMA_V3=MMA_V3 #
+                                   )
+
+
+ class _attention(torch.autograd.Function):
+
+     @staticmethod
+     def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False):
+         # only support for Ampere now
+         capability = torch.cuda.get_device_capability()
+         if capability[0] < 8:
+             raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
+         BLOCK_M = 128
+         BLOCK_N = 64
+         # shape constraints
+         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+         assert Lq == Lk and Lk == Lv
+         assert Lk in {16, 32, 64, 128}
+         o = torch.empty_like(q)
+         grid = (cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
+         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+         num_warps = 4 if Lk <= 64 else 8
+         _fwd_kernel[grid](
+             q, k, v, sm_scale, #
+             L, #
+             o, #
+             q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
+             k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
+             v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
+             o.stride(0), o.stride(1), o.stride(2), o.stride(3), #
+             q.shape[0], q.shape[1], q.shape[2], #
+             q.shape[0] * q.shape[1] * q.shape[2], #
+             BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, #
+             IS_CAUSAL=causal, #
+             num_warps=num_warps, #
+             num_stages=4 #
+         )
+
+         ctx.save_for_backward(q, k, v, o, L)
+         ctx.grid = grid
+         ctx.sm_scale = sm_scale
+         ctx.BLOCK_DMODEL = Lk
+         ctx.causal = causal
+         ctx.sequence_parallel = sequence_parallel
+         return o
+
+     @staticmethod
+     def backward(ctx, do):
+         capability = torch.cuda.get_device_capability()
+         MMA_V3 = capability[0] >= 9
+         BLOCK = 128
+
+         if is_hip():
+             # Bwd pass runs out of shared memory on HIP with larger block size.
+             BLOCK = 64
+
+         q, k, v, o, L = ctx.saved_tensors
+         sequence_parallel = ctx.sequence_parallel
+         seq_len_kv = k.shape[2]
+         do = do.contiguous()
+         if sequence_parallel:
+             replicas = cdiv(seq_len_kv, BLOCK)
+             new_dq_shape = (replicas, ) + q.shape
+             dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype)
+         else:
+             dq = torch.zeros_like(q, dtype=q.dtype)
+         dk = torch.empty_like(k)
+         dv = torch.empty_like(v)
+         delta = torch.empty_like(L)
+         _bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )](
+             o,
+             do,
+             delta,
+             BLOCK_M=BLOCK,
+             D_HEAD=ctx.BLOCK_DMODEL,
+         )
+         _bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)](
+             q, k, v, ctx.sm_scale, #
+             o, do, #
+             dq, dk, dv, #
+             L, #
+             delta, #
+             o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
+             k.stride(0), k.stride(1), k.stride(2), k.stride(3), #
+             v.stride(0), v.stride(1), v.stride(2), v.stride(3), #
+             q.shape[0], q.shape[1], q.shape[2], #
+             q.shape[0] * q.shape[1] * q.shape[2], #
+             cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], #
+             BLOCK_M=BLOCK, BLOCK_N=BLOCK, #
+             BLOCK_DMODEL=ctx.BLOCK_DMODEL, #
+             SEQUENCE_PARALLEL=sequence_parallel, #
+             CAUSAL=ctx.causal, #
+             MMA_V3=MMA_V3, #
+             num_warps=8, #
+             num_stages=1 #
+         )
+
+         if len(dq.shape) == 5:
+             dq = dq.sum(dim=0)
+         return dq, dk, dv, None, None, None
+
+
+ attention = _attention.apply
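
The `attention = _attention.apply` entry point above expects q, k, v laid out as (Z, H, N_CTX, D_HEAD) with D_HEAD in {16, 32, 64, 128}, and the forward pass raises on GPUs with compute capability below 8.0, per the check in `_attention.forward`. A minimal usage sketch, assuming a CUDA device and that `triton.ops` re-exports `attention` (see triton/ops/__init__.py in the file list); the shapes and softmax scale are illustrative only.

import math
import torch
from triton.ops import attention  # assumed re-export, see triton/ops/__init__.py above

# q, k, v: (batch Z, heads H, sequence N_CTX, head_dim D_HEAD); D_HEAD must be in {16, 32, 64, 128}
q, k, v = (torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True) for _ in range(3))
sm_scale = 1.0 / math.sqrt(q.shape[-1])
out = attention(q, k, v, True, sm_scale)  # positional args: causal=True, then sm_scale; sequence_parallel defaults to False
out.sum().backward()                      # dq, dk, dv accumulate into q.grad, k.grad, v.grad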