triton-windows 3.1.0.post17__cp39-cp39-win_amd64.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the registry.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (248)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +73 -0
  3. triton/backends/__init__.py +50 -0
  4. triton/backends/amd/compiler.py +262 -0
  5. triton/backends/amd/driver.c +211 -0
  6. triton/backends/amd/driver.py +497 -0
  7. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +358 -0
  8. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +1031 -0
  9. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +1612 -0
  10. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +1337 -0
  11. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +293 -0
  12. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +32 -0
  13. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +174 -0
  14. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +829 -0
  15. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +1809 -0
  16. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +108 -0
  17. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +124 -0
  18. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +405 -0
  19. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +196 -0
  20. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +565 -0
  21. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +2226 -0
  22. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +104 -0
  23. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +244 -0
  24. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +494 -0
  25. triton/backends/amd/include/hip/amd_detail/concepts.hpp +30 -0
  26. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +133 -0
  27. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +218 -0
  28. triton/backends/amd/include/hip/amd_detail/grid_launch.h +67 -0
  29. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +50 -0
  30. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +26 -0
  31. triton/backends/amd/include/hip/amd_detail/helpers.hpp +137 -0
  32. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +1350 -0
  33. triton/backends/amd/include/hip/amd_detail/hip_assert.h +101 -0
  34. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +242 -0
  35. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +254 -0
  36. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +96 -0
  37. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +100 -0
  38. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +10169 -0
  39. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +77 -0
  40. triton/backends/amd/include/hip/amd_detail/host_defines.h +180 -0
  41. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +102 -0
  42. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +798 -0
  43. triton/backends/amd/include/hip/amd_detail/math_fwd.h +698 -0
  44. triton/backends/amd/include/hip/amd_detail/ockl_image.h +177 -0
  45. triton/backends/amd/include/hip/amd_detail/program_state.hpp +107 -0
  46. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +491 -0
  47. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +478 -0
  48. triton/backends/amd/include/hip/channel_descriptor.h +39 -0
  49. triton/backends/amd/include/hip/device_functions.h +38 -0
  50. triton/backends/amd/include/hip/driver_types.h +468 -0
  51. triton/backends/amd/include/hip/hip_bf16.h +36 -0
  52. triton/backends/amd/include/hip/hip_bfloat16.h +44 -0
  53. triton/backends/amd/include/hip/hip_common.h +100 -0
  54. triton/backends/amd/include/hip/hip_complex.h +38 -0
  55. triton/backends/amd/include/hip/hip_cooperative_groups.h +46 -0
  56. triton/backends/amd/include/hip/hip_deprecated.h +95 -0
  57. triton/backends/amd/include/hip/hip_ext.h +159 -0
  58. triton/backends/amd/include/hip/hip_fp16.h +36 -0
  59. triton/backends/amd/include/hip/hip_gl_interop.h +32 -0
  60. triton/backends/amd/include/hip/hip_hcc.h +24 -0
  61. triton/backends/amd/include/hip/hip_math_constants.h +36 -0
  62. triton/backends/amd/include/hip/hip_profile.h +27 -0
  63. triton/backends/amd/include/hip/hip_runtime.h +75 -0
  64. triton/backends/amd/include/hip/hip_runtime_api.h +8919 -0
  65. triton/backends/amd/include/hip/hip_texture_types.h +29 -0
  66. triton/backends/amd/include/hip/hip_vector_types.h +41 -0
  67. triton/backends/amd/include/hip/hip_version.h +17 -0
  68. triton/backends/amd/include/hip/hiprtc.h +421 -0
  69. triton/backends/amd/include/hip/library_types.h +78 -0
  70. triton/backends/amd/include/hip/math_functions.h +42 -0
  71. triton/backends/amd/include/hip/surface_types.h +63 -0
  72. triton/backends/amd/include/hip/texture_types.h +194 -0
  73. triton/backends/amd/include/hsa/Brig.h +1131 -0
  74. triton/backends/amd/include/hsa/amd_hsa_common.h +91 -0
  75. triton/backends/amd/include/hsa/amd_hsa_elf.h +435 -0
  76. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +269 -0
  77. triton/backends/amd/include/hsa/amd_hsa_queue.h +109 -0
  78. triton/backends/amd/include/hsa/amd_hsa_signal.h +80 -0
  79. triton/backends/amd/include/hsa/hsa.h +5729 -0
  80. triton/backends/amd/include/hsa/hsa_amd_tool.h +91 -0
  81. triton/backends/amd/include/hsa/hsa_api_trace.h +566 -0
  82. triton/backends/amd/include/hsa/hsa_ext_amd.h +3090 -0
  83. triton/backends/amd/include/hsa/hsa_ext_finalize.h +531 -0
  84. triton/backends/amd/include/hsa/hsa_ext_image.h +1454 -0
  85. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +488 -0
  86. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +667 -0
  87. triton/backends/amd/include/roctracer/ext/prof_protocol.h +107 -0
  88. triton/backends/amd/include/roctracer/hip_ostream_ops.h +4435 -0
  89. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +1467 -0
  90. triton/backends/amd/include/roctracer/hsa_prof_str.h +3027 -0
  91. triton/backends/amd/include/roctracer/roctracer.h +779 -0
  92. triton/backends/amd/include/roctracer/roctracer_ext.h +81 -0
  93. triton/backends/amd/include/roctracer/roctracer_hcc.h +24 -0
  94. triton/backends/amd/include/roctracer/roctracer_hip.h +37 -0
  95. triton/backends/amd/include/roctracer/roctracer_hsa.h +112 -0
  96. triton/backends/amd/include/roctracer/roctracer_plugin.h +137 -0
  97. triton/backends/amd/include/roctracer/roctracer_roctx.h +67 -0
  98. triton/backends/amd/include/roctracer/roctx.h +229 -0
  99. triton/backends/amd/lib/ockl.bc +0 -0
  100. triton/backends/amd/lib/ocml.bc +0 -0
  101. triton/backends/compiler.py +76 -0
  102. triton/backends/driver.py +34 -0
  103. triton/backends/nvidia/__init__.py +0 -0
  104. triton/backends/nvidia/bin/ptxas.exe +0 -0
  105. triton/backends/nvidia/compiler.py +347 -0
  106. triton/backends/nvidia/driver.c +451 -0
  107. triton/backends/nvidia/driver.py +430 -0
  108. triton/backends/nvidia/include/cuda.h +24359 -0
  109. triton/backends/nvidia/lib/libdevice.10.bc +0 -0
  110. triton/backends/nvidia/lib/x64/cuda.lib +0 -0
  111. triton/compiler/__init__.py +4 -0
  112. triton/compiler/code_generator.py +1302 -0
  113. triton/compiler/compiler.py +416 -0
  114. triton/compiler/errors.py +51 -0
  115. triton/compiler/make_launcher.py +0 -0
  116. triton/errors.py +5 -0
  117. triton/language/__init__.py +284 -0
  118. triton/language/core.py +2621 -0
  119. triton/language/extra/__init__.py +4 -0
  120. triton/language/extra/cuda/__init__.py +8 -0
  121. triton/language/extra/cuda/libdevice.py +1629 -0
  122. triton/language/extra/cuda/utils.py +109 -0
  123. triton/language/extra/hip/__init__.py +3 -0
  124. triton/language/extra/hip/libdevice.py +468 -0
  125. triton/language/extra/libdevice.py +1213 -0
  126. triton/language/math.py +250 -0
  127. triton/language/random.py +207 -0
  128. triton/language/semantic.py +1621 -0
  129. triton/language/standard.py +441 -0
  130. triton/ops/__init__.py +7 -0
  131. triton/ops/blocksparse/__init__.py +7 -0
  132. triton/ops/blocksparse/matmul.py +432 -0
  133. triton/ops/blocksparse/softmax.py +228 -0
  134. triton/ops/cross_entropy.py +96 -0
  135. triton/ops/flash_attention.py +466 -0
  136. triton/ops/matmul.py +219 -0
  137. triton/ops/matmul_perf_model.py +171 -0
  138. triton/runtime/__init__.py +23 -0
  139. triton/runtime/autotuner.py +361 -0
  140. triton/runtime/build.py +129 -0
  141. triton/runtime/cache.py +289 -0
  142. triton/runtime/driver.py +60 -0
  143. triton/runtime/errors.py +26 -0
  144. triton/runtime/interpreter.py +1127 -0
  145. triton/runtime/jit.py +956 -0
  146. triton/runtime/tcc/include/_mingw.h +170 -0
  147. triton/runtime/tcc/include/assert.h +57 -0
  148. triton/runtime/tcc/include/conio.h +409 -0
  149. triton/runtime/tcc/include/ctype.h +281 -0
  150. triton/runtime/tcc/include/dir.h +31 -0
  151. triton/runtime/tcc/include/direct.h +68 -0
  152. triton/runtime/tcc/include/dirent.h +135 -0
  153. triton/runtime/tcc/include/dos.h +55 -0
  154. triton/runtime/tcc/include/errno.h +75 -0
  155. triton/runtime/tcc/include/excpt.h +123 -0
  156. triton/runtime/tcc/include/fcntl.h +52 -0
  157. triton/runtime/tcc/include/fenv.h +108 -0
  158. triton/runtime/tcc/include/float.h +57 -0
  159. triton/runtime/tcc/include/inttypes.h +297 -0
  160. triton/runtime/tcc/include/io.h +418 -0
  161. triton/runtime/tcc/include/limits.h +111 -0
  162. triton/runtime/tcc/include/locale.h +91 -0
  163. triton/runtime/tcc/include/malloc.h +181 -0
  164. triton/runtime/tcc/include/math.h +737 -0
  165. triton/runtime/tcc/include/mem.h +13 -0
  166. triton/runtime/tcc/include/memory.h +40 -0
  167. triton/runtime/tcc/include/process.h +176 -0
  168. triton/runtime/tcc/include/sec_api/conio_s.h +42 -0
  169. triton/runtime/tcc/include/sec_api/crtdbg_s.h +19 -0
  170. triton/runtime/tcc/include/sec_api/io_s.h +33 -0
  171. triton/runtime/tcc/include/sec_api/mbstring_s.h +52 -0
  172. triton/runtime/tcc/include/sec_api/search_s.h +25 -0
  173. triton/runtime/tcc/include/sec_api/stdio_s.h +145 -0
  174. triton/runtime/tcc/include/sec_api/stdlib_s.h +67 -0
  175. triton/runtime/tcc/include/sec_api/stralign_s.h +30 -0
  176. triton/runtime/tcc/include/sec_api/string_s.h +41 -0
  177. triton/runtime/tcc/include/sec_api/sys/timeb_s.h +34 -0
  178. triton/runtime/tcc/include/sec_api/tchar_s.h +266 -0
  179. triton/runtime/tcc/include/sec_api/time_s.h +61 -0
  180. triton/runtime/tcc/include/sec_api/wchar_s.h +128 -0
  181. triton/runtime/tcc/include/setjmp.h +160 -0
  182. triton/runtime/tcc/include/share.h +28 -0
  183. triton/runtime/tcc/include/signal.h +63 -0
  184. triton/runtime/tcc/include/stdarg.h +79 -0
  185. triton/runtime/tcc/include/stdbool.h +11 -0
  186. triton/runtime/tcc/include/stddef.h +54 -0
  187. triton/runtime/tcc/include/stdint.h +212 -0
  188. triton/runtime/tcc/include/stdio.h +429 -0
  189. triton/runtime/tcc/include/stdlib.h +580 -0
  190. triton/runtime/tcc/include/string.h +164 -0
  191. triton/runtime/tcc/include/sys/fcntl.h +13 -0
  192. triton/runtime/tcc/include/sys/file.h +14 -0
  193. triton/runtime/tcc/include/sys/locking.h +30 -0
  194. triton/runtime/tcc/include/sys/stat.h +290 -0
  195. triton/runtime/tcc/include/sys/time.h +69 -0
  196. triton/runtime/tcc/include/sys/timeb.h +133 -0
  197. triton/runtime/tcc/include/sys/types.h +118 -0
  198. triton/runtime/tcc/include/sys/unistd.h +14 -0
  199. triton/runtime/tcc/include/sys/utime.h +146 -0
  200. triton/runtime/tcc/include/tcc/tcc_libm.h +201 -0
  201. triton/runtime/tcc/include/tcclib.h +80 -0
  202. triton/runtime/tcc/include/tchar.h +1102 -0
  203. triton/runtime/tcc/include/time.h +287 -0
  204. triton/runtime/tcc/include/vadefs.h +11 -0
  205. triton/runtime/tcc/include/values.h +4 -0
  206. triton/runtime/tcc/include/varargs.h +12 -0
  207. triton/runtime/tcc/include/wchar.h +873 -0
  208. triton/runtime/tcc/include/wctype.h +172 -0
  209. triton/runtime/tcc/include/winapi/basetsd.h +149 -0
  210. triton/runtime/tcc/include/winapi/basetyps.h +85 -0
  211. triton/runtime/tcc/include/winapi/guiddef.h +156 -0
  212. triton/runtime/tcc/include/winapi/poppack.h +8 -0
  213. triton/runtime/tcc/include/winapi/pshpack1.h +8 -0
  214. triton/runtime/tcc/include/winapi/pshpack2.h +8 -0
  215. triton/runtime/tcc/include/winapi/pshpack4.h +8 -0
  216. triton/runtime/tcc/include/winapi/pshpack8.h +8 -0
  217. triton/runtime/tcc/include/winapi/winbase.h +2951 -0
  218. triton/runtime/tcc/include/winapi/wincon.h +301 -0
  219. triton/runtime/tcc/include/winapi/windef.h +293 -0
  220. triton/runtime/tcc/include/winapi/windows.h +127 -0
  221. triton/runtime/tcc/include/winapi/winerror.h +3166 -0
  222. triton/runtime/tcc/include/winapi/wingdi.h +4080 -0
  223. triton/runtime/tcc/include/winapi/winnt.h +5835 -0
  224. triton/runtime/tcc/include/winapi/winreg.h +272 -0
  225. triton/runtime/tcc/include/winapi/winuser.h +5651 -0
  226. triton/runtime/tcc/include/winapi/winver.h +160 -0
  227. triton/runtime/tcc/lib/cuda.def +697 -0
  228. triton/runtime/tcc/lib/gdi32.def +337 -0
  229. triton/runtime/tcc/lib/kernel32.def +770 -0
  230. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  231. triton/runtime/tcc/lib/msvcrt.def +1399 -0
  232. triton/runtime/tcc/lib/python3.def +810 -0
  233. triton/runtime/tcc/lib/user32.def +658 -0
  234. triton/runtime/tcc/libtcc.dll +0 -0
  235. triton/runtime/tcc/tcc.exe +0 -0
  236. triton/testing.py +496 -0
  237. triton/tools/__init__.py +0 -0
  238. triton/tools/build_extern.py +365 -0
  239. triton/tools/compile.c +67 -0
  240. triton/tools/compile.h +14 -0
  241. triton/tools/compile.py +145 -0
  242. triton/tools/disasm.py +142 -0
  243. triton/tools/link.py +322 -0
  244. triton/windows_utils.py +373 -0
  245. triton_windows-3.1.0.post17.dist-info/METADATA +41 -0
  246. triton_windows-3.1.0.post17.dist-info/RECORD +248 -0
  247. triton_windows-3.1.0.post17.dist-info/WHEEL +5 -0
  248. triton_windows-3.1.0.post17.dist-info/top_level.txt +14 -0
triton/language/standard.py ADDED
@@ -0,0 +1,441 @@
+ from __future__ import annotations
+
+ from ..runtime.jit import jit
+ from . import core
+ from . import math
+
+ # constexpr utilities (triton metaprogramming sucks)
+
+
+ def _unwrap_if_constexpr(o):
+     return o.value if isinstance(o, core.constexpr) else o
+
+
+ def _log2(i: core.constexpr):
+     log2 = 0
+     n = i.value
+     while n > 1:
+         n >>= 1
+         log2 += 1
+     return core.constexpr(log2)
+
+
+ def _is_power_of_two(i: core.constexpr):
+     n = i.value
+     return core.constexpr((n & (n - 1)) == 0 and n != 0)
+
+
+ # -----------------------
+ # Standard library
+ # -----------------------
+
+
+ @core._tensor_member_fn
+ @jit
+ def cdiv(x, div):
+     """
+     Computes the ceiling division of :code:`x` by :code:`div`
+
+     :param x: the input number
+     :type x: Block
+     :param div: the divisor
+     :type div: Block
+     """
+     return (x + div - 1) // div
+
+
+ @core._tensor_member_fn
+ @jit
+ @math._add_math_1arg_docstr("sigmoid")
+ def sigmoid(x):
+     return 1 / (1 + math.exp(-x))
+
+
+ @core._tensor_member_fn
+ @jit
+ @math._add_math_1arg_docstr("softmax")
+ def softmax(x, ieee_rounding=False):
+     z = x - max(x, 0)
+     num = math.exp(z)
+     den = sum(num, 0)
+     return math.fdiv(num, den, ieee_rounding)
+
+
+ @core._tensor_member_fn
+ @jit
+ def ravel(x):
+     """
+     Returns a contiguous flattened view of :code:`x`.
+
+     :param x: the input tensor
+     :type x: Block
+     """
+     return core.reshape(x, [x.numel], can_reorder=True)
+
+
+ @jit
+ def swizzle2d(i, j, size_i, size_j, size_g):
+     """
+     Transforms indices of a row-major :code:`size_i * size_j` matrix into those
+     of one where the indices are col-major for each group of :code:`size_g`
+     rows.
+
+     For example, for :code:`size_i = size_j = 4` and :code:`size_g = 2`, it will
+     transform ::
+
+         [[0 , 1 , 2 , 3 ],
+          [4 , 5 , 6 , 7 ],
+          [8 , 9 , 10, 11],
+          [12, 13, 14, 15]]
+
+     into ::
+
+         [[0, 2, 4 , 6 ],
+          [1, 3, 5 , 7 ],
+          [8, 10, 12, 14],
+          [9, 11, 13, 15]]
+     """
+     # "unrolled index in array"
+     ij = i * size_j + j
+     # number of elements in `size_g` groups
+     # of `size_j` columns
+     size_gj = size_g * size_j
+     # index of the group in which (i,j) is
+     group_id = ij // size_gj
+     # row-index of the first element of this group
+     off_i = group_id * size_g
+     # last group may have fewer rows
+     size_g = core.minimum(size_i - off_i, size_g)
+     # new row and column indices
+     new_i = off_i + (ij % size_g)
+     new_j = (ij % size_gj) // size_g
+     return new_i, new_j
+
+
+ @jit
+ def zeros(shape, dtype):
+     """
+     Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
+
+     :param shape: Shape of the new array, e.g., (8, 16) or (8, )
+     :type shape: tuple of ints
+     :param dtype: Data-type of the new array, e.g., :code:`tl.float16`
+     :type dtype: DType
+     """
+     return core.full(shape, 0, dtype)
+
+
+ @jit
+ def zeros_like(input):
+     """
+     Creates a tensor of zeros with the same shape and type as a given tensor.
+     """
+     return zeros(input.shape, input.dtype)
+
+
+ # max and argmax
+
+
+ @jit
+ def _argmax_combine(value1, index1, value2, index2, tie_break_left):
+     if tie_break_left:
+         tie = value1 == value2 and index1 < index2
+     else:
+         tie = False
+     gt = value1 > value2 or tie
+     v_ret = core.where(gt, value1, value2)
+     i_ret = core.where(gt, index1, index2)
+     return v_ret, i_ret
+
+
+ @jit
+ def _argmax_combine_tie_break_left(value1, index1, value2, index2):
+     return _argmax_combine(value1, index1, value2, index2, True)
+
+
+ @jit
+ def _argmax_combine_tie_break_fast(value1, index1, value2, index2):
+     return _argmax_combine(value1, index1, value2, index2, False)
+
+
+ @jit
+ def _elementwise_max(a, b):
+     return core.maximum(a, b)
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_reduction_docstr("maximum", return_indices_arg="return_indices",
+                             tie_break_arg="return_indices_tie_break_left")
+ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
+     input = core._promote_bfloat16_to_float32(input)
+     if return_indices:
+         if return_indices_tie_break_left:
+             return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_left, keep_dims=keep_dims)
+         else:
+             return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast, keep_dims=keep_dims)
+     else:
+         if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
+             if core.constexpr(input.dtype.is_floating()):
+                 input = input.to(core.float32)
+             else:
+                 assert input.dtype.is_int(), "Expecting input to be integer type"
+                 input = input.to(core.int32)
+         return core.reduce(input, axis, _elementwise_max, keep_dims=keep_dims)
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_reduction_docstr("maximum index", tie_break_arg="tie_break_left")
+ def argmax(input, axis, tie_break_left=True, keep_dims=False):
+     (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims)
+     return ret
+
+
+ # min and argmin
+
+
+ @jit
+ def _argmin_combine(value1, index1, value2, index2, tie_break_left):
+     if tie_break_left:
+         tie = value1 == value2 and index1 < index2
+     else:
+         tie = False
+     lt = value1 < value2 or tie
+     value_ret = core.where(lt, value1, value2)
+     index_ret = core.where(lt, index1, index2)
+     return value_ret, index_ret
+
+
+ @jit
+ def _argmin_combine_tie_break_left(value1, index1, value2, index2):
+     return _argmin_combine(value1, index1, value2, index2, True)
+
+
+ @jit
+ def _argmin_combine_tie_break_fast(value1, index1, value2, index2):
+     return _argmin_combine(value1, index1, value2, index2, False)
+
+
+ @jit
+ def _elementwise_min(a, b):
+     return core.minimum(a, b)
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_reduction_docstr("minimum", return_indices_arg="return_indices",
+                             tie_break_arg="return_indices_tie_break_left")
+ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False):
+     input = core._promote_bfloat16_to_float32(input)
+     if return_indices:
+         if return_indices_tie_break_left:
+             return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_left, keep_dims=keep_dims)
+         else:
+             return core._reduce_with_indices(input, axis, _argmin_combine_tie_break_fast, keep_dims=keep_dims)
+     else:
+         if core.constexpr(input.dtype.primitive_bitwidth) < 32:
+             if core.constexpr(input.dtype.is_floating()):
+                 input = input.to(core.float32)
+             else:
+                 assert input.dtype.is_int(), "Expecting input to be integer type"
+                 input = input.to(core.int32)
+         return core.reduce(input, axis, _elementwise_min, keep_dims=keep_dims)
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left")
+ def argmin(input, axis, tie_break_left=True, keep_dims=False):
+     _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left, keep_dims=keep_dims)
+     return ret
+
+
+ @jit
+ def _sum_combine(a, b):
+     return a + b
+
+
+ # sum
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_reduction_docstr("sum")
+ def sum(input, axis=None, keep_dims=False):
+     input = core._promote_bfloat16_to_float32(input)
+     return core.reduce(input, axis, _sum_combine, keep_dims=keep_dims)
+
+
+ @jit
+ def _xor_combine(a, b):
+     return a ^ b
+
+
+ # xor sum
+
+
+ @core._tensor_member_fn
+ @core.builtin
+ @core._add_reduction_docstr("xor sum")
+ def xor_sum(input, axis=None, keep_dims=False, _builder=None, _generator=None):
+     scalar_ty = input.type.scalar
+     if not scalar_ty.is_int():
+         raise ValueError("xor_sum only supported for integers")
+
+     input = core._promote_bfloat16_to_float32(input, _builder=_builder)
+     return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims, _builder=_builder, _generator=_generator)
+
+
+ # cumsum
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_scan_docstr("cumsum")
+ def cumsum(input, axis=0, reverse=False):
+     # todo rename this to a generic function name
+     input = core._promote_bfloat16_to_float32(input)
+     return core.associative_scan(input, axis, _sum_combine, reverse)
+
+
+ # cumprod
+
+
+ @jit
+ def _prod_combine(a, b):
+     return a * b
+
+
+ @core._tensor_member_fn
+ @jit
+ @core._add_scan_docstr("cumprod")
+ def cumprod(input, axis=0, reverse=False):
+     # todo rename this to a generic function name
+     input = core._promote_bfloat16_to_float32(input)
+     return core.associative_scan(input, axis, _prod_combine, reverse)
+
+
+ # sort
+
+
+ @jit
+ def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr):
+     n_outer: core.constexpr = x.numel >> n_dims
+     shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
+     y = core.reshape(x, shape)
+     # slice left/right with 'stride' 2**(n_dims - i - 1)
+     mask = core.arange(0, 2)[None, :, None]
+     left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape)
+     right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape)
+     left = core.reshape(left, x.shape)
+     right = core.reshape(right, x.shape)
+     # actual compare-and-swap
+     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+     ileft = left.to(idtype, bitcast=True)
+     iright = right.to(idtype, bitcast=True)
+     ix = x.to(idtype, bitcast=True)
+     ret = ix ^ core.where((left > right) ^ flip, ileft ^ iright, zeros_like(ix))
+     return ret.to(x.dtype, bitcast=True)
+
+
+ @jit
+ def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):
+     '''
+     order_type 0 == ascending
+     order_type 1 == descending
+     order_type 2 == alternating
+     '''
+     n_outer: core.constexpr = x.numel >> n_dims
+     core.static_assert(stage <= n_dims)
+     # flip denotes whether to re-arrange sub-sequences of elements in ascending or
+     # descending order.
+     # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
+     # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
+     # a stride of 2) at this stage
+     if order == 2:
+         shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]
+         flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape)
+     else:
+         flip = order
+     # perform `stage` rounds of `compare-and-swap`
+     for i in core.static_range(stage):
+         x = _compare_and_swap(x, flip, i + (n_dims - stage), n_dims)
+     return x
+
+
+ @core._tensor_member_fn
+ @jit
+ def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+     # handle default dimension or check that it is the most minor dim
+     _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
+     core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
+     # iteratively run bitonic merge-sort steps
+     n_dims: core.constexpr = _log2(x.shape[_dim])
+     for i in core.static_range(1, n_dims + 1):
+         x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims)
+     return x
+
+
+ # flip
+
+
+ def _get_flip_dim(dim, shape):
+     dim = _unwrap_if_constexpr(dim)
+     shape = _unwrap_if_constexpr(shape)
+     if dim is None:
+         dim = len(shape) - 1
+     assert dim == len(shape) - 1, "Currently only support flipping the last dimension"
+     return core.constexpr(dim)
+
+
+ @core._tensor_member_fn
+ @jit
+ def flip(x, dim=None):
+     """
+     Flips a tensor `x` along the dimension `dim`.
+
+     :param x: the first input tensor
+     :type x: Block
+     :param dim: the dimension to flip along (currently only final dimension supported)
+     :type dim: int
+     """
+     core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)]))
+     core.static_assert(_is_power_of_two(x.numel))
+     # reshape the tensor to have all dimensions be 2.
+     # TODO: We shouldn't have to change the dimensions not sorted.
+     steps: core.constexpr = _log2(x.numel)
+     start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)])
+     y = core.reshape(x, [2] * steps)
+     y = core.expand_dims(y, start)
+     flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2))
+     for i in core.static_range(start, steps):
+         flip2 = flip
+         for j in core.static_range(0, steps + 1):
+             if j != i and j != i + 1:
+                 flip2 = core.expand_dims(flip2, j)
+         y = sum(y * flip2, i + 1, keep_dims=True)
+     x = core.reshape(y, x.shape)
+     return x
+
+
+ @jit
+ def interleave(a, b):
+     """
+     Interleaves the values of two tensors along their last dimension.
+
+     The two tensors must have the same shape.
+
+     Equivalent to `tl.join(a, b).reshape(a.shape[:-1] + [2 * a.shape[-1]])`
+     """
+     c = core.join(a, b)
+
+     assert isinstance(c.shape, list)
+     if len(c.shape) == 1:
+         # We must have interleaved two scalars.
+         return c
+     else:
+         # This `else` is necessary because Triton's AST parser doesn't
+         # understand that if we take the `if` above we definitely don't run this
+         # `else`.
+         return core.reshape(c, c.shape[:-2] + [2 * c.shape[-2]])
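
Note for readers skimming the diff: the functions in standard.py above are re-exported through the tl (triton.language) namespace, so kernels call them as tl.cdiv, tl.max, tl.sort, tl.softmax, and so on. Below is a minimal sketch (not part of the wheel) of driving tl.softmax from a @triton.jit kernel; the kernel name, the one-program-per-row launch, and the contiguous row-major layout are illustrative assumptions:

import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(x_ptr, y_ptr, n_cols, BLOCK: tl.constexpr):
    # one program instance per row of a contiguous (n_rows, n_cols) matrix
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < n_cols
    # masked lanes load -inf, so they contribute exp(-inf) == 0 to the denominator
    x = tl.load(x_ptr + row * n_cols + offs, mask=mask, other=-float("inf"))
    # tl.softmax is the standard-library softmax defined in the file above
    tl.store(y_ptr + row * n_cols + offs, tl.softmax(x), mask=mask)


x = torch.randn(64, 100, device="cuda")
y = torch.empty_like(x)
softmax_kernel[(x.shape[0], )](x, y, x.shape[1], BLOCK=triton.next_power_of_2(x.shape[1]))
torch.testing.assert_close(y, torch.softmax(x, dim=1))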
triton/ops/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # from .conv import _conv, conv
+ from . import blocksparse
+ from .cross_entropy import _cross_entropy, cross_entropy
+ from .flash_attention import attention
+ from .matmul import _matmul, get_higher_dtype, matmul
+
+ __all__ = ["blocksparse", "_cross_entropy", "cross_entropy", "_matmul", "matmul", "attention", "get_higher_dtype"]
triton/ops/blocksparse/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .matmul import matmul
+ from .softmax import softmax
+
+ __all__ = [
+     "matmul",
+     "softmax",
+ ]
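
The blocksparse package re-exported here pairs a block-sparse matmul with a block-sparse softmax. A hedged sketch of the matmul entry point, assuming the constructor signature matmul(layout, block, mode, trans_a=..., trans_b=..., device=...) that upstream Triton's blocksparse/matmul.py defines; the all-ones layout, tile size, and "sdd" mode are illustrative:

import torch
from triton.ops.blocksparse import matmul

block = 16
Z, H, M, N, K = 1, 2, 64, 64, 64
# one 0/1 flag per (block x block) tile of the sparse output
layout = torch.ones(H, M // block, N // block, dtype=torch.int64)
# "sdd" mode: sparse output = dense a @ dense b
op = matmul(layout, block, "sdd", trans_a=False, trans_b=False, device="cuda")
a = torch.randn(Z, H, M, K, device="cuda", dtype=torch.float16)
b = torch.randn(Z, H, K, N, device="cuda", dtype=torch.float16)
# returns only the nonzero tiles, packed as (Z, layout.sum(), block, block)
c = op(a, b)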