triton-windows 3.3.1.post19__cp313-cp313-win_amd64.whl → 3.5.0.post21__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of triton-windows has been flagged as a potentially problematic release.

Files changed (225)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +11 -2
  3. triton/_filecheck.py +97 -0
  4. triton/_internal_testing.py +95 -18
  5. triton/_utils.py +112 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +161 -119
  9. triton/backends/amd/driver.c +118 -46
  10. triton/backends/amd/driver.py +274 -96
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/driver.py +13 -0
  13. triton/backends/nvidia/bin/ptxas.exe +0 -0
  14. triton/backends/nvidia/compiler.py +163 -106
  15. triton/backends/nvidia/driver.c +166 -101
  16. triton/backends/nvidia/driver.py +384 -202
  17. triton/compiler/__init__.py +5 -2
  18. triton/compiler/code_generator.py +439 -231
  19. triton/compiler/compiler.py +152 -84
  20. triton/experimental/__init__.py +0 -0
  21. triton/experimental/gluon/__init__.py +5 -0
  22. triton/experimental/gluon/_compiler.py +0 -0
  23. triton/experimental/gluon/_runtime.py +102 -0
  24. triton/experimental/gluon/language/__init__.py +119 -0
  25. triton/experimental/gluon/language/_core.py +490 -0
  26. triton/experimental/gluon/language/_layouts.py +583 -0
  27. triton/experimental/gluon/language/_math.py +20 -0
  28. triton/experimental/gluon/language/_semantic.py +380 -0
  29. triton/experimental/gluon/language/_standard.py +80 -0
  30. triton/experimental/gluon/language/amd/__init__.py +4 -0
  31. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  32. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  33. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  34. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  35. triton/experimental/gluon/language/extra/__init__.py +3 -0
  36. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  37. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  38. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  39. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  40. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +387 -0
  41. triton/experimental/gluon/language/nvidia/blackwell/tma.py +52 -0
  42. triton/experimental/gluon/language/nvidia/hopper/__init__.py +132 -0
  43. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +34 -0
  44. triton/experimental/gluon/language/nvidia/hopper/tma.py +97 -0
  45. triton/experimental/gluon/nvidia/__init__.py +4 -0
  46. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  47. triton/experimental/gluon/nvidia/hopper.py +45 -0
  48. triton/knobs.py +546 -0
  49. triton/language/__init__.py +50 -19
  50. triton/language/core.py +909 -572
  51. triton/language/extra/cuda/__init__.py +10 -7
  52. triton/language/extra/cuda/gdc.py +42 -0
  53. triton/language/extra/cuda/libdevice.py +394 -394
  54. triton/language/extra/cuda/utils.py +21 -21
  55. triton/language/extra/hip/__init__.py +3 -1
  56. triton/language/extra/hip/libdevice.py +120 -104
  57. triton/language/extra/hip/utils.py +35 -0
  58. triton/language/extra/libdevice.py +4 -0
  59. triton/language/math.py +65 -66
  60. triton/language/random.py +12 -2
  61. triton/language/semantic.py +1757 -1768
  62. triton/language/standard.py +127 -62
  63. triton/language/target_info.py +54 -0
  64. triton/runtime/_allocation.py +15 -3
  65. triton/runtime/_async_compile.py +55 -0
  66. triton/runtime/autotuner.py +117 -60
  67. triton/runtime/build.py +83 -17
  68. triton/runtime/cache.py +61 -47
  69. triton/runtime/driver.py +25 -47
  70. triton/runtime/interpreter.py +95 -50
  71. triton/runtime/jit.py +445 -248
  72. triton/runtime/tcc/include/_mingw.h +8 -10
  73. triton/runtime/tcc/include/assert.h +5 -0
  74. triton/runtime/tcc/include/errno.h +1 -1
  75. triton/runtime/tcc/include/float.h +21 -3
  76. triton/runtime/tcc/include/iso646.h +36 -0
  77. triton/runtime/tcc/include/limits.h +5 -0
  78. triton/runtime/tcc/include/malloc.h +2 -2
  79. triton/runtime/tcc/include/math.h +21 -261
  80. triton/runtime/tcc/include/stdalign.h +16 -0
  81. triton/runtime/tcc/include/stdarg.h +5 -70
  82. triton/runtime/tcc/include/stdatomic.h +171 -0
  83. triton/runtime/tcc/include/stddef.h +7 -19
  84. triton/runtime/tcc/include/stdlib.h +15 -4
  85. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  86. triton/runtime/tcc/include/sys/stat.h +2 -2
  87. triton/runtime/tcc/include/sys/types.h +5 -0
  88. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  89. triton/runtime/tcc/include/tccdefs.h +342 -0
  90. triton/runtime/tcc/include/tgmath.h +89 -0
  91. triton/runtime/tcc/include/uchar.h +33 -0
  92. triton/runtime/tcc/include/unistd.h +1 -0
  93. triton/runtime/tcc/include/winapi/qos.h +72 -0
  94. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  95. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  96. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  97. triton/runtime/tcc/include/winapi/windows.h +1 -1
  98. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  99. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  100. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  101. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  102. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  103. triton/runtime/tcc/lib/libtcc1.a +0 -0
  104. triton/runtime/tcc/lib/python314.def +1800 -0
  105. triton/runtime/tcc/lib/python314t.def +1809 -0
  106. triton/runtime/tcc/libtcc.dll +0 -0
  107. triton/runtime/tcc/tcc.exe +0 -0
  108. triton/testing.py +16 -12
  109. triton/tools/compile.py +62 -14
  110. triton/tools/disasm.py +3 -4
  111. triton/tools/extra/cuda/compile.c +1 -0
  112. triton/tools/extra/hip/compile.cpp +66 -0
  113. triton/tools/extra/hip/compile.h +13 -0
  114. triton/tools/ragged_tma.py +92 -0
  115. triton/tools/tensor_descriptor.py +34 -0
  116. triton/windows_utils.py +52 -81
  117. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +8 -4
  118. triton_windows-3.5.0.post21.dist-info/RECORD +217 -0
  119. triton_windows-3.5.0.post21.dist-info/entry_points.txt +3 -0
  120. triton_windows-3.5.0.post21.dist-info/licenses/LICENSE +23 -0
  121. triton_windows-3.5.0.post21.dist-info/top_level.txt +1 -0
  122. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  123. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  124. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  125. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  126. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  127. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  128. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  129. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  130. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  131. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  132. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  133. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  134. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  135. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  136. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  137. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  138. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  139. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  140. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  141. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  142. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  143. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  144. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  145. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  146. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  147. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  148. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  149. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  150. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  151. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  152. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  153. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  154. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  155. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  156. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  157. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  158. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  159. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  160. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  161. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  162. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  163. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  164. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  165. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  166. triton/backends/amd/include/hip/device_functions.h +0 -38
  167. triton/backends/amd/include/hip/driver_types.h +0 -468
  168. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  169. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  170. triton/backends/amd/include/hip/hip_common.h +0 -100
  171. triton/backends/amd/include/hip/hip_complex.h +0 -38
  172. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  173. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  174. triton/backends/amd/include/hip/hip_ext.h +0 -161
  175. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  176. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  177. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  178. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  179. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  180. triton/backends/amd/include/hip/hip_profile.h +0 -27
  181. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  182. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  183. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  184. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  185. triton/backends/amd/include/hip/hip_version.h +0 -17
  186. triton/backends/amd/include/hip/hiprtc.h +0 -421
  187. triton/backends/amd/include/hip/library_types.h +0 -78
  188. triton/backends/amd/include/hip/math_functions.h +0 -42
  189. triton/backends/amd/include/hip/surface_types.h +0 -63
  190. triton/backends/amd/include/hip/texture_types.h +0 -194
  191. triton/backends/amd/include/hsa/Brig.h +0 -1131
  192. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  193. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  194. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  195. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  196. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  197. triton/backends/amd/include/hsa/hsa.h +0 -5738
  198. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  199. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  200. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  201. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  202. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  203. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  204. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  205. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  206. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  207. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  208. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  209. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  210. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  211. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  212. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  213. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  214. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  215. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  216. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  217. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  218. triton/backends/amd/include/roctracer/roctx.h +0 -229
  219. triton/language/_utils.py +0 -21
  220. triton/language/extra/cuda/_experimental_tma.py +0 -106
  221. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  222. triton/tools/experimental_descriptor.py +0 -32
  223. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  224. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  225. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
triton/language/standard.py

@@ -1,24 +1,25 @@
 from __future__ import annotations

-from ..runtime.jit import jit
+from ..runtime.jit import jit, constexpr_function
 from . import core
 from . import math

 # constexpr utilities


-def _log2(i: core.constexpr):
+@constexpr_function
+def _log2(i):
     log2 = 0
-    n = i.value
+    n = i
     while n > 1:
         n >>= 1
         log2 += 1
-    return core.constexpr(log2)
+    return log2


-def _is_power_of_two(i: core.constexpr):
-    n = i.value
-    return core.constexpr((n & (n - 1)) == 0 and n != 0)
+@constexpr_function
+def _is_power_of_two(i):
+    return (i & (i - 1)) == 0 and i != 0


 # -----------------------
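The move from constexpr-typed helpers to @constexpr_function means these utilities now receive and return plain Python values; the decorator takes care of wrapping the result back into a constexpr when the helper is called from JIT-compiled code. A minimal sketch of the pattern, assuming the decorator unwraps tl.constexpr arguments the same way _log2 relies on above (the kernel and helper below are illustrative, not part of this package):

import triton
import triton.language as tl
from triton.runtime.jit import constexpr_function


@constexpr_function
def _next_pow2(n):
    # evaluated at compile time on a plain Python int
    p = 1
    while p < n:
        p *= 2
    return p


@triton.jit
def copy_padded_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # the helper call folds to a compile-time constant
    PADDED: tl.constexpr = _next_pow2(BLOCK)
    offs = tl.arange(0, PADDED)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x, mask=mask)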
@@ -50,10 +51,14 @@ def sigmoid(x):
 @core._tensor_member_fn
 @jit
 @math._add_math_1arg_docstr("softmax")
-def softmax(x, ieee_rounding=False):
-    z = x - max(x, 0)
+def softmax(x, dim=None, keep_dims=False, ieee_rounding=False):
+    if dim is None:
+        _dim: core.constexpr = 0
+    else:
+        _dim: core.constexpr = dim
+    z = x - max(x, _dim, keep_dims=keep_dims)
     num = math.exp(z)
-    den = sum(num, 0)
+    den = sum(num, _dim, keep_dims=keep_dims)
     return math.fdiv(num, den, ieee_rounding)
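With the added dim and keep_dims arguments, softmax can reduce along an explicit axis of a multi-dimensional block instead of always along axis 0. A hedged usage sketch (the kernel name and block shapes are illustrative; tl.softmax is assumed to be the triton.language re-export of this function):

import triton
import triton.language as tl


@triton.jit
def rowwise_softmax_kernel(x_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr):
    offs_m = tl.arange(0, M)[:, None]
    offs_n = tl.arange(0, N)[None, :]
    x = tl.load(x_ptr + offs_m * N + offs_n)
    # reduce along the last axis of the (M, N) block; keep_dims=True keeps the
    # internal max/sum broadcastable against x
    y = tl.softmax(x, dim=1, keep_dims=True)
    tl.store(out_ptr + offs_m * N + offs_n, y)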
@@ -259,8 +264,8 @@ def _sum_combine(a, b):
 # sum


-def _pick_sum_dtype(in_dtype: core.constexpr, dtype: core.constexpr):
-    dtype = core._unwrap_if_constexpr(dtype)
+@constexpr_function
+def _pick_sum_dtype(in_dtype, dtype):
     if dtype is not None:
         return dtype
@@ -302,15 +307,37 @@ def xor_sum(input, axis=None, keep_dims=False):
     return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims)


+# or reduction
+
+
+@jit
+def _or_combine(x, y):
+    return x | y
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("reduce_or")
+def reduce_or(input, axis, keep_dims=False):
+    core.static_assert(input.type.scalar.is_int(), "reduce_or only supported for integers")
+    return core.reduce(input, axis, _or_combine, keep_dims=keep_dims)
+
+
 # cumsum


 @core._tensor_member_fn
 @jit
-@core._add_scan_docstr("cumsum")
-def cumsum(input, axis=0, reverse=False):
+@core._add_scan_docstr("cumsum", dtype_arg="dtype")
+def cumsum(input, axis=0, reverse=False, dtype: core.constexpr = None):
     # todo rename this to a generic function name
+
     input = core._promote_bfloat16_to_float32(input)
+    out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)
+
+    if out_dtype is not None:
+        input = input.to(out_dtype)
+
     return core.associative_scan(input, axis, _sum_combine, reverse)
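Two behavioural additions land here: a bitwise-OR reduction (reduce_or) and an optional accumulation dtype for cumsum. A hedged sketch of both, assuming they are re-exported under triton.language as is usual for functions defined in standard.py (the kernel below is illustrative):

import triton
import triton.language as tl


@triton.jit
def or_and_prefix_sum_kernel(flags_ptr, vals_ptr, any_ptr, csum_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    flags = tl.load(flags_ptr + offs)   # integer flags
    vals = tl.load(vals_ptr + offs)     # narrow integers, e.g. int8

    # bitwise-OR reduction over the block (new reduce_or)
    any_set = tl.reduce_or(flags, axis=0)
    tl.store(any_ptr, any_set)

    # cumulative sum accumulated in a wider dtype (new dtype argument)
    csum = tl.cumsum(vals, axis=0, dtype=tl.int32)
    tl.store(csum_ptr + offs, csum)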
@@ -335,53 +362,63 @@ def cumprod(input, axis=0, reverse=False):


 @jit
-def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr):
-    n_outer: core.constexpr = x.numel >> n_dims
-    shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
-    y = core.reshape(x, shape)
-    # slice left/right with 'stride' 2**(n_dims - i - 1)
-    mask = core.arange(0, 2)[None, :, None]
-    left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype)
-    right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape).to(y.dtype)
-    left = core.reshape(left, x.shape)
-    right = core.reshape(right, x.shape)
-    # actual compare-and-swap
+def _indicator(n_dims: core.constexpr, j: core.constexpr):
+    ar = core.arange(0, 2)
+    ar = core.reshape(ar, [1] * (n_dims - j - 1) + [2] + [1] * j)
+    return ar
+
+
+@jit
+def _compare_and_swap(x, flip, i: core.constexpr):
+    # compare-and-swap on the ith *innermost* dimension
+    n_dims: core.constexpr = _log2(x.numel)
+
+    # flip along middle dimension (the bitwise XORs will be optimised away):
     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
-    ileft = left.to(idtype, bitcast=True)
-    iright = right.to(idtype, bitcast=True)
     ix = x.to(idtype, bitcast=True)
-    ret = ix ^ core.where((left > right) != flip, ileft ^ iright, zeros_like(ix))
-    return ret.to(x.dtype, bitcast=True)
+    iy = ix ^ xor_sum(ix, n_dims - 1 - i, True)
+    y = iy.to(x.dtype, bitcast=True)
+
+    # determines whether we are in the right (rather than left) position along the axis:
+    is_right = _indicator(n_dims, i)
+
+    # conditional swap:
+    ret = core.where((x > y) != (flip ^ is_right), y, x)
+    return ret


 @jit
-def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):
+def _bitonic_merge_hypercube(x, stage: core.constexpr, order: core.constexpr):
     '''
     order_type 0 == ascending
     order_type 1 == descending
     order_type 2 == alternating
     '''
-    n_outer: core.constexpr = x.numel >> n_dims
-    core.static_assert(stage <= n_dims)
     # flip denotes whether to re-arrange sub-sequences of elements in ascending or
     # descending order.
     # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
     # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
     # a stride of 2) at this stage
     if order == 2:
-        shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]
-        flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape)
+        flip = _indicator(_log2(x.numel), stage)
     else:
         flip = order
     # perform `stage` rounds of `compare-and-swap`
     for i in core.static_range(stage):
-        x = _compare_and_swap(x, flip, i + (n_dims - stage), n_dims)
+        x = _compare_and_swap(x, flip, stage - 1 - i)
     return x


-@core._tensor_member_fn
 @jit
-def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):
+    h = core.reshape(x, [2] * _log2(x.numel))
+    h = _bitonic_merge_hypercube(h, stage, order)
+    x = core.reshape(h, x.shape)
+    return x
+
+
+@jit
+def sort_impl(x, k: core.constexpr = None, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
     """
     Sorts a tensor along a specified dimension.

@@ -389,29 +426,64 @@ def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTE
     :type x: Tensor
     :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported.
     :type dim: int, optional
+    :param k: the number of top elements to select. If none, assume k = x.shape[dim]
+    :type k: int, optional
     :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order.
     :type descending: bool, optional
     """
     # handle default dimension or check that it is the most minor dim
     _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
     core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
-    # iteratively run bitonic merge-sort steps
-    n_dims: core.constexpr = _log2(x.shape[_dim])
-    for i in core.static_range(1, n_dims + 1):
-        x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims)
+
+    log_n: core.constexpr = _log2(x.shape[_dim])
+    log_k: core.constexpr = log_n if k is None else _log2(k)
+
+    n_dims: core.constexpr = _log2(x.numel)
+
+    # reshape to hypercube:
+    h = core.reshape(x, [2] * n_dims)
+
+    # run first log_k bitonic sort iterations:
+    for i in core.static_range(1, log_k + 1):
+        h = _bitonic_merge_hypercube(h, i, 2 if i < log_n else descending)
+
+    # select top k elements using bitonic top-k
+    # https://www.doc.ic.ac.uk/~hlgr/pdfs/MassivelyParallelTopK.pdf
+    for i in core.static_range(log_k + 1, log_n + 1):
+        h = max(h, axis=(_log2(h.numel) - 1 - log_k)) if descending else min(h, axis=(_log2(h.numel) - 1 - log_k))
+        h = _bitonic_merge_hypercube(h, log_k, 2 if i < log_n else descending)
+
+    # reshape back:
+    x = core.reshape(h, x.shape[:-1] + [2**log_k])
     return x


-# flip
+@jit
+def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    return sort_impl(x, dim=dim, descending=descending)
+
+
+@jit
+def topk(x, k: core.constexpr, dim: core.constexpr = None):
+    return sort_impl(x, k=k, dim=dim, descending=True)
+
+
+@jit
+def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    # handle default dimension or check that it is the most minor dim
+    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
+    core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
+    n_dims: core.constexpr = _log2(x.shape[-1])
+    return _bitonic_merge(x, n_dims, descending, n_dims)


+@constexpr_function
 def _get_flip_dim(dim, shape):
-    dim = core._unwrap_if_constexpr(dim)
-    shape = core._unwrap_if_constexpr(shape)
     if dim is None:
         dim = len(shape) - 1
-    assert dim == len(shape) - 1, "Currently only support flipping the last dimension"
-    return core.constexpr(dim)
+    if dim < 0:  # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
+        dim += len(shape)
+    return dim


 @core._tensor_member_fn
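The sorting path is reorganised around a shared sort_impl, which also powers the new topk and an exposed bitonic_merge. A hedged usage sketch, assuming topk is re-exported as tl.topk by the updated triton/language/__init__.py (that file's diff is not shown in this excerpt); the block size and k must be powers of two per the bitonic implementation:

import triton
import triton.language as tl


@triton.jit
def topk_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr, K: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # returns the K largest elements of the block, sorted in descending order
    top = tl.topk(x, K)
    tl.store(out_ptr + tl.arange(0, K), top)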
@@ -422,26 +494,19 @@ def flip(x, dim=None):

     :param x: the first input tensor
     :type x: Block
-    :param dim: the dimension to flip along (currently only final dimension supported)
+    :param dim: the dimension to flip along
     :type dim: int
     """
-    core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)]))
-    core.static_assert(_is_power_of_two(x.numel))
-    # reshape the tensor to have all dimensions be 2.
-    # TODO: We shouldn't have to change the dimensions not sorted.
-    steps: core.constexpr = _log2(x.numel)
-    start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)])
+    core.static_assert(-len(x.shape) <= dim and dim < len(x.shape))
+    _dim: core.constexpr = _get_flip_dim(dim, x.shape)
+    core.static_assert(_is_power_of_two(x.shape[_dim]))
+    steps: core.constexpr = _log2(x.shape[_dim])

+    # reshape the swap dimension to (2, 2, ..., 2)
     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
-    y = core.reshape(x.to(idtype, bitcast=True), [2] * steps)
-    y = core.expand_dims(y, start)
-    flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2))
-    for i in core.static_range(start, steps):
-        flip2 = flip
-        for j in core.static_range(0, steps + 1):
-            if j != i and j != i + 1:
-                flip2 = core.expand_dims(flip2, j)
-        y = sum(y * flip2, i + 1, keep_dims=True, dtype=y.dtype)
+    y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + [2] * steps + x.shape[_dim + 1:])
+    for i in core.static_range(steps):
+        y = y ^ xor_sum(y, _dim + i, True)
     x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
     return x
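flip no longer reshapes the whole tensor into a hypercube; only the flipped dimension is expanded and the reversal is done with xor_sum swaps, which lifts the old last-dimension-only restriction. A hedged sketch of flipping a non-final dimension (shapes are illustrative; the flipped dimension must still be a power of two):

import triton
import triton.language as tl


@triton.jit
def flip_rows_kernel(x_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr):
    offs_m = tl.arange(0, M)[:, None]
    offs_n = tl.arange(0, N)[None, :]
    x = tl.load(x_ptr + offs_m * N + offs_n)
    # reverse the first axis of the (M, N) block; previously only dim=-1 was supported
    y = tl.flip(x, dim=0)
    tl.store(out_ptr + offs_m * N + offs_n, y)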
triton/language/target_info.py (new file)

@@ -0,0 +1,54 @@
+from triton.runtime import driver
+from triton.runtime.jit import constexpr_function
+
+__all__ = ["current_target"]
+
+
+def current_target():
+    try:
+        active_driver = driver.active
+    except RuntimeError:
+        # If there is no active driver, return None
+        return None
+    return active_driver.get_current_target()
+
+
+current_target.__triton_builtin__ = True
+
+
+@constexpr_function
+def is_cuda():
+    target = current_target()
+    return target is not None and target.backend == "cuda"
+
+
+@constexpr_function
+def cuda_capability_geq(major, minor=0):
+    """
+    Determines whether we have compute capability >= (major, minor) and
+    returns this as a constexpr boolean. This can be used for guarding
+    inline asm implementations that require a certain compute capability.
+    """
+    target = current_target()
+    if target is None or target.backend != "cuda":
+        return False
+    assert isinstance(target.arch, int)
+    return target.arch >= major * 10 + minor
+
+
+@constexpr_function
+def is_hip():
+    target = current_target()
+    return target is not None and target.backend == "hip"
+
+
+@constexpr_function
+def is_hip_cdna3():
+    target = current_target()
+    return target is not None and target.arch == "gfx942"
+
+
+@constexpr_function
+def is_hip_cdna4():
+    target = current_target()
+    return target is not None and target.arch == "gfx950"
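The new triton.language.target_info module exposes constexpr_function predicates that resolve at compile time, so backend- or capability-specific paths can be selected without runtime branching. A hedged sketch of guarding a code path on compute capability (the branch bodies are placeholders, not real backend-specific code):

import triton
import triton.language as tl
from triton.language.target_info import cuda_capability_geq, is_hip


@triton.jit
def dispatch_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    if cuda_capability_geq(9, 0):
        y = x * 2   # Hopper-or-newer path
    elif is_hip():
        y = x + x   # AMD path
    else:
        y = x * 2   # generic fallback
    tl.store(out_ptr + offs, y)

Because the predicates return constexpr booleans, only the selected branch should be code-generated.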
triton/runtime/_allocation.py

@@ -1,4 +1,5 @@
 from typing import Optional, Protocol
+from contextvars import ContextVar


 class Buffer(Protocol):
@@ -20,7 +21,7 @@ class NullAllocator:
                            "Use triton.set_allocator to specify an allocator.")


-_allocator: Allocator = NullAllocator()
+_allocator: ContextVar[Allocator] = ContextVar("_allocator", default=NullAllocator())


 def set_allocator(allocator: Allocator):
@@ -28,5 +29,16 @@ def set_allocator(allocator: Allocator):
     The allocator function is called during kernel launch for kernels that
     require additional global memory workspace.
     """
-    global _allocator
-    _allocator = allocator
+    _allocator.set(allocator)
+
+
+_profile_allocator: Allocator = ContextVar("_allocator", default=NullAllocator())
+
+
+def set_profile_allocator(allocator: Optional[Allocator]):
+    """
+    The profile allocator function is called before kernel launch for kernels
+    that require additional global memory workspace.
+    """
+    global _profile_allocator
+    _profile_allocator.set(allocator)
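Storing the allocator in a ContextVar makes it per-context rather than a process-wide global, but the public entry point is still triton.set_allocator. A hedged sketch of registering an allocator backed by torch (torch here is an assumption; any callable returning an object with a data_ptr() method satisfies the Buffer protocol):

import torch
import triton


def torch_allocator(size: int, alignment: int, stream):
    # a torch tensor exposes data_ptr(), which is all the Buffer protocol requires
    return torch.empty(size, dtype=torch.int8, device="cuda")


triton.set_allocator(torch_allocator)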
triton/runtime/_async_compile.py (new file)

@@ -0,0 +1,55 @@
+from __future__ import annotations
+from typing import Callable, Optional
+from concurrent.futures import Executor, as_completed, Future
+from contextvars import ContextVar
+
+active_mode: ContextVar[Optional[AsyncCompileMode]] = ContextVar("async_compile_active_mode", default=None)
+
+
+class FutureKernel:
+
+    def __init__(self, finalize_compile: Callable, future: Future):
+        self.finalize_compile = finalize_compile
+        self.kernel = None
+        self.future = future
+
+    def result(self):
+        if self.kernel is not None:
+            return self.kernel
+
+        kernel = self.future.result()
+        self.finalize_compile(kernel)
+        self.kernel = kernel
+        return kernel
+
+
+class AsyncCompileMode:
+
+    def __init__(self, executor: Executor):
+        self.executor = executor
+        self.raw_futures = []
+        self.future_kernels = {}
+
+    def submit(self, key, compile_fn, finalize_fn):
+        future = self.future_kernels.get(key)
+        if future is not None:
+            return future
+
+        future = self.executor.submit(compile_fn)
+        future._key = key
+        self.raw_futures.append(future)
+        future_kernel = FutureKernel(finalize_fn, future)
+        self.future_kernels[key] = future_kernel
+        return future_kernel
+
+    def __enter__(self):
+        if active_mode.get() is not None:
+            raise RuntimeError("Another AsyncCompileMode is already active")
+        active_mode.set(self)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Finalize any outstanding compiles
+        for future in as_completed(self.raw_futures):
+            self.future_kernels[future._key].result()
+        active_mode.set(None)
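The new _async_compile module offloads kernel compilation to a user-supplied Executor: compiles are submitted as futures while a mode is active, and __exit__ drains and finalizes whatever is still outstanding. A hedged sketch of driving it directly (the JIT's integration with active_mode lives in runtime/jit.py, which is not shown here; the compile/finalize callables below are stand-ins):

from concurrent.futures import ThreadPoolExecutor
from triton.runtime._async_compile import AsyncCompileMode


def compile_fn():
    return "compiled-kernel"      # stand-in for the real compile closure


def finalize_fn(kernel):
    print("finalized:", kernel)   # stand-in for post-compile finalization


executor = ThreadPoolExecutor(max_workers=4)
with AsyncCompileMode(executor) as mode:
    future_kernel = mode.submit("kernel-key", compile_fn, finalize_fn)
    # other work can proceed; __exit__ waits for and finalizes outstanding compiles

kernel = future_kernel.result()   # already finalized by the time the block exits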