triton-windows 3.1.0.post17__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (248)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
triton/ops/blocksparse/matmul.py
@@ -0,0 +1,432 @@
+ import torch
+
+ from ... import cdiv, heuristics, jit
+ from ... import language as tl
+
+ # ********************************************************
+ # --------------------------------------------------------
+ # Sparse = Dense x Dense (SDD)
+ # This operation uses super-blocking to make sure that
+ # it's done efficiently when small blocks can be grouped
+ # together
+ # --------------------------------------------------------
+ # ********************************************************
+
+
+ @heuristics({
+     'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
+ })
+ @jit
+ def _sdd_kernel(A, B, C,  #
+                 stride_za, stride_ha, stride_ma, stride_ak,  #
+                 stride_zb, stride_hb, stride_bk, stride_nb,  #
+                 stride_zc, stride_hc, stride_mc, stride_nc,  #
+                 K, grid_offset, lut,  #
+                 TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,  #
+                 BLOCK: tl.constexpr, EVEN_K: tl.constexpr  #
+                 ):
+     # ------------ #
+     # - Prologue - #
+     # ------------ #
+     block_id = tl.program_id(0) + grid_offset
+     lut += block_id * 3
+     # offsets
+     off_z = tl.program_id(2)  # batch
+     off_h = tl.load(lut + 0)  # head
+
+     # initialize pointers to A
+     start_am = tl.load(lut + 1)
+     offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
+     offs_ak = tl.arange(0, TILE_K)
+     a_ptrs = A \
+         + off_z * stride_za \
+         + off_h * stride_ha \
+         + offs_am[:, None] * stride_ma \
+         + offs_ak[None, :] * stride_ak
+     # initialize pointers to B
+     start_bn = tl.load(lut + 2)
+     offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
+     offs_bk = tl.arange(0, TILE_K)
+     b_ptrs = B \
+         + off_z * stride_zb \
+         + off_h * stride_hb \
+         + offs_bn[None, :] * stride_nb \
+         + offs_bk[:, None] * stride_bk
+     # ---------------- #
+     #    Inner Loop    #
+     # ---------------- #
+     acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+     for k in range(K, 0, -TILE_K):
+         if EVEN_K:
+             a = tl.load(a_ptrs)
+             b = tl.load(b_ptrs)
+         else:
+             a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)
+             b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)
+         acc += tl.dot(a, b, out_dtype=tl.float32)
+         a_ptrs += TILE_K * stride_ak
+         b_ptrs += TILE_K * stride_bk
+     c = acc.to(C.dtype.element_ty)
+     # ---------------- #
+     #     Epilogue     #
+     # ---------------- #
+     offs_cm = tl.arange(0, TILE_M) % BLOCK
+     offs_cn = tl.arange(0, TILE_N) % BLOCK
+     pc = C \
+         + off_z * stride_zc \
+         + block_id * stride_hc \
+         + offs_cm[:, None] * stride_mc \
+         + offs_cn[None, :] * stride_nc
+     tl.store(pc, c, mask=True)
+
+
+ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
+     if a.stride(2) != 1 and a.stride(3) != 1:
+         a = a.contiguous()
+     if b.stride(2) != 1 and b.stride(3) != 1:
+         b = b.contiguous()
+     # (A * B)^T = B^T * A^T
+     if trans_c:
+         a, b = b, a
+         trans_a, trans_b = not trans_b, not trans_a
+     # shape constraints
+     a_dim = -2 if trans_a else -1
+     b_dim = -1 if trans_b else -2
+     Ka, Kb = a.shape[a_dim], b.shape[b_dim]
+     if Ka != Kb:
+         raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})")
+     # allocate output
+     if out is None:
+         c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device)
+     else:
+         assert out.shape == (a.shape[0], lut.shape[0], block, block)
+         c = out
+     grid = [c.shape[1], 1, c.shape[0]]
+     _sdd_kernel[grid](
+         a, b, c,  #
+         a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),  #
+         b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),  #
+         c.stride(0), c.stride(1), c.stride(2), c.stride(3),  #
+         Ka, 0, lut,  #
+         TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,  #
+         num_warps=4  #
+     )
+     return c
+
+
+ def sdd_lut(layout, block, device):
+     lut = layout.nonzero(as_tuple=False).to(device).int()
+     lut = lut.contiguous()
+     return lut, None
+
+
+ # -----------------------------
+ # Dense = Sparse x Dense (DSD)
+ # This operation uses a look-up table that contains pre-computed pointer increments
+ # in order to minimize computations in the inner loop of the matmul kernel.
+ # -----------------------------
+
+
+ @jit
+ def _dsd_kernel(A, B, C,  #
+                 stride_az, stride_ha, stride_am, stride_ak,  #
+                 stride_zb, stride_hb, stride_bk, stride_bn,  #
+                 stride_zc, stride_hc, stride_cm, stride_cn,  #
+                 DS0, DS1, lut,  #
+                 TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,  #
+                 GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr  #
+                 ):
+     # ------------ #
+     # - Prologue - #
+     # ------------ #
+     pid_m = tl.program_id(0)
+     pid_n = tl.program_id(1)
+     num_pid_m = tl.num_programs(0)
+     num_pid_n = tl.num_programs(1)
+     pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
+     pidz = tl.program_id(2)
+     header = lut + pid_n * 4
+     offset = tl.load(header + 0)
+     K = tl.load(header + 1)
+     column = tl.load(header + 2)
+     off_h = tl.load(header + 3)
+     pinc = lut + offset
+     # initialize pointers to A (sparse)
+     block_id = tl.load(pinc + 1)
+     block_id = tl.multiple_of(block_id, 8)  # compiler hint
+     offs_am = tl.arange(0, TILE_M)
+     offs_ak = tl.arange(0, TILE_K)
+     pa = A + pidz * stride_az \
+         + block_id * stride_ha \
+         + offs_am[:, None] * stride_am \
+         + offs_ak[None, :] * stride_ak
+     # initialize pointers to B (dense)
+     offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
+     offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
+     start_bk = tl.load(pinc)
+     start_bk = tl.multiple_of(start_bk, 8)  # compiler hint
+     offs_bk = start_bk + tl.arange(0, TILE_K)
+     pb = B + pidz * stride_zb \
+         + off_h * stride_hb \
+         + offs_bn[None, :] * stride_bn \
+         + offs_bk[:, None] * stride_bk
+     # ---------------- #
+     #    Inner Loop    #
+     # ---------------- #
+     acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
+     pinc += 2
+     inc_a = tl.load(pinc + 1)
+     inc_a = tl.multiple_of(inc_a, 8)
+     inc_b = tl.load(pinc)
+     inc_b = tl.multiple_of(inc_b, 8)
+     for k in range(K, 0, -TILE_K):
+         a = tl.load(pa)
+         b = tl.load(pb)
+         acc += tl.dot(a, b, out_dtype=tl.float32)
+         pa += inc_a
+         pb += inc_b * stride_bk
+         pinc += 2
+         inc_a = tl.load(pinc + 1)
+         inc_a = tl.multiple_of(inc_a, 8)
+         inc_b = tl.load(pinc)
+         inc_b = tl.multiple_of(inc_b, 8)
+     c = acc.to(C.dtype.element_ty)
+     # initialize pointers to C
+     offs_cm = column * TILE_M + tl.arange(0, TILE_M)
+     offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
+     pc = C \
+         + off_h * stride_hc \
+         + pidz * stride_zc \
+         + offs_cm[:, None] * stride_cm \
+         + offs_cn[None, :] * stride_cn
+     tl.store(pc, c, mask=offs_cn[None, :] < DS0)
+
+
+ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
+     if a.stride(2) != 1 and a.stride(3) != 1:
+         a = a.contiguous()
+     if b.stride(2) != 1 and b.stride(3) != 1:
+         b = b.contiguous()
+     # shapes / dtypes
+     AS1 = block * spdims[2 if trans_a else 1]
+     BS0 = b.size(0)
+     BS1 = b.size(1)
+     BS3 = b.size(2 if trans_b else 3)
+     dtype = a.dtype
+     # allocate output
+     CS0 = BS0
+     CS1 = BS1
+     CS2 = BS3 if trans_c else AS1
+     CS3 = AS1 if trans_c else BS3
+     if out is None:
+         c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
+     else:
+         assert out.shape == (CS0, CS1, CS2, CS3)
+         c = out
+     # meta-parameter heuristics
+     TILE_N = 128
+     # compute output
+     grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0]
+     _dsd_kernel[grid](
+         a, b, c,  #
+         a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),  #
+         b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),  #
+         c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),  #
+         BS3, AS1, lut,  #
+         TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,  #
+         num_warps=4, GROUP_SIZE_M=4  #
+     )
+     # exit()
+     return c
+
+
+ def dsd_lut(layout, block, step, trans, device):
+     """
+     Generates the look-up table for incrementing pointers in the DSD/DDS matmul.
+     Example (BLOCK=32, STEP=16)
+     [[1, 0, 0, 1, 0],
+      [0, 1, 1, 0, 1],
+      [1, 0, 1, 0, 0]]
+
+     Then the offsets for A are
+      [0 , 16, 32, 48] <- row 0
+       \\----/  \\----/
+       col=0    col=3
+      [64, 80, 96, 112, 128, 144] <- row 1
+       \\----/  \\----/  \\------/
+       col=1    col=2     col=3
+      [160, 176, 192, 208]
+     which leads to increments table
+     [0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16]
+
+     Because B is dense, the offsets are
+      [0, 16, 96, 112] <- row 0
+      [32, 48, 64, 80] <- row 1
+      [0, 16, 64, 80] <- row 2
+     """
+     sizes = torch.sum(layout, 2 if trans else 1)
+     head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True)
+     sizes = sizes.flatten()
+     segments = sizes * step
+     # pointer increments
+     if trans:
+         nnz = layout.nonzero(as_tuple=False)
+     else:
+         nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
+     num_blocks = nnz.size(0)
+     offsets = torch.zeros_like(sizes)
+     offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
+     offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
+     # -------------------------------
+     # dense input pointer increments
+     # -------------------------------
+     # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K)
+     # that is smaller than the block size, so we need to do a bit of extra work
+     # to handle this case
+     B_idx = nnz[:, 2] * block
+     B_incs = B_idx.clone()
+     B_incs[1:] -= B_idx[:-1]
+     div = block // step
+     B_incs = B_incs.view(-1, 1).repeat(1, div)
+     B_incs[:, 1:] = step
+     B_incs[:, 0] -= (div - 1) * step
+     # first increment for each reduction is actually the offset
+     B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]]
+     B_incs = B_incs.view(-1)
+     # -------------------------------
+     # sparse input pointer increments
+     # -------------------------------
+     # same as above, except that the increments are in the sparse memory layout
+     if trans:
+         A_idx = torch.arange(num_blocks, device=layout.device)
+     else:
+         A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
+         current_offset = 0
+         for z in range(layout.size(0)):
+             layoutw = layout[z, :, :].clone().long()
+             msum = layoutw.sum()
+             layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)
+             A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1))
+             current_offset += msum
+     A_incs = A_idx * block * block
+     A_incs[1:] -= A_idx[:-1] * block * block
+     A_incs = A_incs.view(-1, 1).repeat(1, div)
+     if trans:
+         A_incs[:, 1:] = step
+         A_incs[:, 0] -= (div - 1) * step
+     else:
+         A_incs[:, 1:] = step * block
+         A_incs[:, 0] -= (div - 1) * step * block
+     A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]]
+     A_incs = A_incs.view(-1)
+     # create header
+     width = col_id.size(0)
+     offsets = offsets * 2 * div + 4 * width
+     segments = segments * div
+     header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
+     # create increments
+     incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
+     # pad by a factor 2*MAX_NUM_STAGES
+     # to accommodate pre-fetching inside the kernel
+     pad = torch.zeros(20, device=incs.device, dtype=incs.dtype)
+     incs = torch.cat((incs, pad))
+     # create lut
+     lut = torch.cat((header, incs))
+     lut = lut.type(torch.int32).to(device)
+     # create locks
+     return lut, width
+
+
+ # -----------------------------
+ # Dense = Dense x Sparse (DDS)
+ # -----------------------------
+ # AB = (B^T A^T)^T
+
+
+ def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
+     return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
+
+
+ ##############
+ #  MAIN API  #
+ ##############
+
+
+ class _matmul(torch.autograd.Function):
+
+     fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
+
+     @staticmethod
+     def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut,
+                 db_width, out):
+         c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
+         # save for backward
+         ctx.save_for_backward(a, b)
+         ctx.da_lut = da_lut
+         ctx.da_width = da_width
+         ctx.db_lut = db_lut
+         ctx.db_width = db_width
+         ctx.mode = mode
+         ctx.spdims = spdims
+         ctx.block = block
+         ctx.trans_a = trans_a
+         ctx.trans_b = trans_b
+         ctx.trans_c = trans_c
+         ctx.has_out = out is not None
+         return c
+
+     @staticmethod
+     def backward(ctx, dc):
+         # saved for backward
+         a, b = ctx.saved_tensors
+         da, db = None, None
+         mode = ctx.mode
+         # gradients w.r.t. a
+         if ctx.needs_input_grad[0]:
+             mode_da = mode[1] + mode[0] + mode[2]
+             da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block,
+                                      ctx.da_lut, ctx.da_width)
+         # gradients w.r.t. b
+         if ctx.needs_input_grad[1]:
+             mode_db = mode[2] + mode[1] + mode[0]
+             db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block,
+                                      ctx.db_lut, ctx.db_width)
+         dout = dc if ctx.has_out else None
+         return da, db, None, None, None, \
+             None, None, None, None, \
+             None, None, None, None, None, dout
+
+
+ class matmul:
+
+     def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False):
+         if mode not in ['sdd', 'dsd', 'dds']:
+             raise NotImplementedError('Supported modes are: sdd, dsd, dds')
+         self.block = block
+         self.mode = mode
+         self.trans_a = trans_a
+         self.trans_b = trans_b
+         self.trans_c = trans_c
+         self.layout = layout
+         self.spdims = layout.shape
+         step = min(block, 32)
+         if self.mode == 'sdd':
+             self.c_lut, self.c_width = sdd_lut(layout, block, device)
+             self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device)
+             self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device)
+         if self.mode == 'dsd':
+             self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device)
+             self.da_lut, self.da_width = sdd_lut(layout, block, device)
+             self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device)
+         if self.mode == 'dds':
+             self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device)
+             self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
+             self.db_lut, self.db_width = sdd_lut(layout, block, device)
+
+     def __call__(self, a, b, out=None):
+         c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,  #
+                           self.c_lut, self.c_width,  #
+                           self.da_lut, self.da_width,  #
+                           self.db_lut, self.db_width,  #
+                           out)
+         return c
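
For orientation, a minimal sketch of driving the blocksparse matmul above. It assumes a CUDA device and that the class is importable as triton.ops.blocksparse.matmul from this wheel; the layout, block size, and shapes are made up for the example.

import torch
from triton.ops.blocksparse import matmul

# Illustrative sparsity pattern: 2 heads over a 4x4 grid of 16x16 blocks,
# keeping the lower triangle (a causal-attention-style layout).
block = 16
layout = torch.tril(torch.ones(2, 4, 4, dtype=torch.int64))

# 'sdd' mode computes C (sparse) = A (dense) x B (dense); the constructor
# precomputes the LUTs used by the forward pass and both backward modes.
dot = matmul(layout, block, 'sdd', device='cuda', trans_a=False, trans_b=True)

a = torch.randn(1, 2, 64, 64, dtype=torch.float16, device='cuda', requires_grad=True)
b = torch.randn(1, 2, 64, 64, dtype=torch.float16, device='cuda', requires_grad=True)
c = dot(a, b)       # shape (1, layout.sum(), block, block): one tile per nonzero block
c.sum().backward()  # gradients are routed through the dsd/dds kernels
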
triton/ops/blocksparse/softmax.py
@@ -0,0 +1,228 @@
+ import torch
+
+ from ... import jit
+ from ... import language as tl
+ from ... import next_power_of_2
+
+
+ def num_warps(n):
+     if n <= 128:
+         return 1
+     if n <= 256:
+         return 2
+     if n <= 512:
+         return 4
+     if n <= 4096:
+         return 8
+     return 16
+
+
+ @jit
+ def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT,  #
+                              R, extent, stride_zr, stride_hr,  # relative attention
+                              scale, is_causal,  #
+                              ROW_SIZE: tl.constexpr,  #
+                              BLOCK_SIZE: tl.constexpr,  #
+                              IS_DENSE: tl.constexpr  #
+                              ):
+     h = tl.program_id(0)
+     m = tl.program_id(1)
+     z = tl.program_id(2)
+     # create index ranges
+     hm = h * tl.num_programs(1) + m
+     lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
+     block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
+     # extract information from LUT
+     header = LUT + (hm // BLOCK_SIZE) * 2
+     size = tl.load(header + 0)
+     offset = tl.load(header + 1)
+     # pointer offset
+     off_a = z * stride_xz
+     off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE  # block indx
+     off_a += (m % BLOCK_SIZE) * BLOCK_SIZE  # row indx
+     # do not need to read column indices in the dense case
+     if IS_DENSE:
+         ns = tl.arange(0, ROW_SIZE)
+     else:
+         off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
+         start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
+         ns = start_n * BLOCK_SIZE + lane_n
+     # load X
+     mask = block_n < size
+     a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf"))
+     a = a.to(tl.float32)
+     # compute
+     out = a
+     out *= scale
+     # apply relative attention
+     if R is not None:
+         R += z * stride_zr
+         R += h * stride_hr
+         off_lo = (extent - m - 1) + ns
+         mask_lo = (off_lo >= 0) & (off_lo < extent)
+         rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)
+         out += rel_logits
+     out = out.to(tl.float32)
+     # apply causal mask
+     out = tl.where((ns > m) & is_causal, -float("inf"), out)
+     # computation
+     out = tl.softmax(out)
+     # write-back
+     tl.store(Out + off_a + lane_n, out, mask=mask)
+
+
+ @jit
+ def _blocksparse_softmax_bwd(DA, stride_zdx,  #
+                              DOut, stride_zdout,  #
+                              Out, stride_zout,  #
+                              scale,  #
+                              LUT,  #
+                              DR, extent, stride_zr, stride_hr, stride_er,  #
+                              is_causal,  #
+                              ROW_SIZE: tl.constexpr,  #
+                              BLOCK_SIZE: tl.constexpr,  #
+                              IS_DENSE: tl.constexpr):
+     h = tl.program_id(0)
+     m = tl.program_id(1)
+     z = tl.program_id(2)
+     # create index ranges
+     hm = h * tl.num_programs(1) + m
+     lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
+     block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
+     # extract information from LUT
+     header = LUT + (hm // BLOCK_SIZE) * 2
+     size = tl.load(header + 0)
+     offset = tl.load(header + 1)
+     # row-col offset
+     off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE
+     off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE
+     mask = block_n < size
+     # pointers
+     As = Out + z * stride_zout + off_mn
+     DOuts = DOut + z * stride_zdout + off_mn
+     # do not need to read column indices in the dense case
+     if IS_DENSE:
+         ns = tl.arange(0, ROW_SIZE)
+     else:
+         off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
+         start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)
+         ns = start_n * BLOCK_SIZE + lane_n
+     # load data
+     a = tl.load(As + lane_n, mask=mask, other=0.0)
+     a = a.to(tl.float32)
+     dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)
+     dout = dout.to(tl.float32)
+     # compute
+     a = tl.where((ns > m) & is_causal & (a == a), 0., a)
+     da = a * (dout - tl.sum(a * dout, 0))
+     # apply relative attention
+     if DR is not None:
+         DR += z * stride_zr
+         DR += h * stride_hr
+         off_lo = (extent - m - 1) + ns
+         mask_lo = (off_lo >= 0) & (off_lo < extent) & mask
+         tl.store(DR + m * extent + off_lo, da, mask=mask_lo)
+     da = da * scale
+     # convert da
+     # write-back
+     DAs = DA + z * stride_zdx + off_mn
+     tl.store(DAs + lane_n, da, mask=mask)
+
+
+ class _softmax(torch.autograd.Function):
+
+     @staticmethod
+     def make_lut(layout, block, device):
+         _empty = torch.tensor([], dtype=torch.int64, device=layout.device)
+         sizes = _empty.clone()
+         # sizes along rows
+         for h in range(layout.shape[0]):
+             sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
+         total_sizes = sizes * block
+         # offsets in block format
+         offsets = torch.zeros_like(sizes)
+         offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
+         # block indices
+         columns = layout.nonzero(as_tuple=False)[:, 2]
+         header = torch.stack((sizes, offsets), dim=1).view(-1)
+         lut = torch.cat((header, columns)).type(torch.int32).to(device)
+         return lut, int(total_sizes.max())
+
+     @staticmethod
+     def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense):
+         if scale is not None and isinstance(scale, torch.Tensor):
+             assert scale.device.type == "cpu"
+             scale = scale.item()
+         M = a.shape[0]
+         grid = [spdims[0], spdims[1] * block, M]
+         rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape
+         rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()
+         # enqueue kernel
+         out = torch.empty_like(a)
+         _blocksparse_softmax_fwd[grid](
+             out, a, a.stride(0), lut,  #
+             rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1],  # relative attn#
+             scale,  #
+             is_causal,  #
+             BLOCK_SIZE=block,  #
+             ROW_SIZE=next_power_of_2(maxlut),  #
+             IS_DENSE=is_dense,  #
+             num_warps=num_warps(maxlut)  #
+         )
+         # save to context
+         # ctx.mark_dirty(x)
+         ctx.save_for_backward(out, lut)
+         ctx.spdims = spdims
+         ctx.block = block
+         ctx.maxlut = maxlut
+         ctx.scale = scale
+         ctx.rel_shape = rel_shape
+         ctx.rel_strides = rel_strides
+         ctx.rel_dtype = a.dtype
+         ctx.is_dense = is_dense
+         ctx.is_causal = is_causal
+         return out
+
+     @staticmethod
+     def backward(ctx, dout):
+         # retrieve from context
+         out, lut = ctx.saved_tensors
+         # relative logits gradients
+         dr = None
+         if ctx.needs_input_grad[3]:
+             dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)
+         # run kernel
+         M = out.shape[0]
+         grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
+         da = torch.empty_like(dout)
+         _blocksparse_softmax_bwd[grid](
+             da, da.stride(0),  #
+             dout, dout.stride(0),  #
+             out, out.stride(0),  #
+             ctx.scale,  #
+             lut,  #
+             dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],  #
+             ctx.is_causal,  #
+             BLOCK_SIZE=ctx.block,  #
+             ROW_SIZE=next_power_of_2(ctx.maxlut),  #
+             IS_DENSE=ctx.is_dense,  #
+             num_warps=num_warps(ctx.maxlut)  #
+         )
+         return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
+
+
+ class softmax:
+
+     def __init__(self, layout, block, device, is_dense=False):
+         self.spdims = layout.shape
+         self.layout = layout
+         self.block = block
+         self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device)
+         self.is_dense = is_dense
+
+     def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
+         if rel_logits is not None and rel_logits.dtype != a.dtype:
+             raise ValueError(f"relative position embedding must be {a.dtype}")
+         a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut,
+                            self.is_dense)
+         return a
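
Similarly, a minimal sketch for the blocksparse softmax above, using the same illustrative layout and block size as the matmul example and again assuming a CUDA device. The input follows the tile format the 'sdd' matmul emits: (batch, number of nonzero blocks, block, block).

import torch
from triton.ops.blocksparse import softmax

block = 16
layout = torch.tril(torch.ones(2, 4, 4, dtype=torch.int64))

# make_lut runs once in the constructor; maxlut bounds ROW_SIZE for the kernel.
sm = softmax(layout, block, device='cuda')

nnz = int(layout.sum())
x = torch.randn(1, nnz, block, block, dtype=torch.float16, device='cuda')
y = sm(x, scale=1.0 / 8.0, is_causal=True)  # row-wise softmax over each sparse row
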