triton-windows 3.3.1.post19__cp39-cp39-win_amd64.whl → 3.4.0.post20__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of triton-windows might be problematic.

Files changed (166)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +4 -1
  3. triton/_filecheck.py +87 -0
  4. triton/_internal_testing.py +26 -15
  5. triton/_utils.py +110 -21
  6. triton/backends/__init__.py +20 -23
  7. triton/backends/amd/__init__.py +0 -0
  8. triton/backends/amd/compiler.py +112 -78
  9. triton/backends/amd/driver.c +5 -2
  10. triton/backends/amd/driver.py +149 -47
  11. triton/backends/compiler.py +7 -21
  12. triton/backends/nvidia/bin/ptxas.exe +0 -0
  13. triton/backends/nvidia/compiler.py +92 -93
  14. triton/backends/nvidia/driver.c +90 -98
  15. triton/backends/nvidia/driver.py +303 -128
  16. triton/compiler/code_generator.py +212 -111
  17. triton/compiler/compiler.py +110 -25
  18. triton/experimental/__init__.py +0 -0
  19. triton/experimental/gluon/__init__.py +4 -0
  20. triton/experimental/gluon/_compiler.py +0 -0
  21. triton/experimental/gluon/_runtime.py +99 -0
  22. triton/experimental/gluon/language/__init__.py +18 -0
  23. triton/experimental/gluon/language/_core.py +312 -0
  24. triton/experimental/gluon/language/_layouts.py +230 -0
  25. triton/experimental/gluon/language/_math.py +12 -0
  26. triton/experimental/gluon/language/_semantic.py +287 -0
  27. triton/experimental/gluon/language/_standard.py +47 -0
  28. triton/experimental/gluon/language/nvidia/__init__.py +4 -0
  29. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +202 -0
  30. triton/experimental/gluon/language/nvidia/blackwell/tma.py +32 -0
  31. triton/experimental/gluon/language/nvidia/hopper/__init__.py +11 -0
  32. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +51 -0
  33. triton/experimental/gluon/language/nvidia/hopper/tma.py +96 -0
  34. triton/experimental/gluon/nvidia/__init__.py +4 -0
  35. triton/experimental/gluon/nvidia/blackwell.py +3 -0
  36. triton/experimental/gluon/nvidia/hopper.py +40 -0
  37. triton/knobs.py +481 -0
  38. triton/language/__init__.py +39 -14
  39. triton/language/core.py +794 -537
  40. triton/language/extra/cuda/__init__.py +10 -7
  41. triton/language/extra/cuda/gdc.py +42 -0
  42. triton/language/extra/cuda/libdevice.py +394 -394
  43. triton/language/extra/cuda/utils.py +21 -21
  44. triton/language/extra/hip/libdevice.py +113 -104
  45. triton/language/math.py +65 -66
  46. triton/language/random.py +12 -2
  47. triton/language/semantic.py +1706 -1770
  48. triton/language/standard.py +116 -51
  49. triton/runtime/autotuner.py +117 -59
  50. triton/runtime/build.py +76 -12
  51. triton/runtime/cache.py +18 -47
  52. triton/runtime/driver.py +32 -29
  53. triton/runtime/interpreter.py +72 -35
  54. triton/runtime/jit.py +146 -110
  55. triton/testing.py +16 -12
  56. triton/tools/disasm.py +3 -4
  57. triton/tools/tensor_descriptor.py +36 -0
  58. triton/windows_utils.py +14 -6
  59. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/METADATA +7 -2
  60. triton_windows-3.4.0.post20.dist-info/RECORD +186 -0
  61. triton_windows-3.4.0.post20.dist-info/entry_points.txt +3 -0
  62. triton_windows-3.4.0.post20.dist-info/licenses/LICENSE +23 -0
  63. triton_windows-3.4.0.post20.dist-info/top_level.txt +1 -0
  64. triton/backends/amd/include/hip/amd_detail/amd_channel_descriptor.h +0 -358
  65. triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +0 -1010
  66. triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +0 -1638
  67. triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +0 -1814
  68. triton/backends/amd/include/hip/amd_detail/amd_hip_bfloat16.h +0 -293
  69. triton/backends/amd/include/hip/amd_detail/amd_hip_common.h +0 -32
  70. triton/backends/amd/include/hip/amd_detail/amd_hip_complex.h +0 -174
  71. triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +0 -835
  72. triton/backends/amd/include/hip/amd_detail/amd_hip_fp16.h +0 -1809
  73. triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +0 -1391
  74. triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +0 -108
  75. triton/backends/amd/include/hip/amd_detail/amd_hip_math_constants.h +0 -124
  76. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime.h +0 -405
  77. triton/backends/amd/include/hip/amd_detail/amd_hip_runtime_pt_api.h +0 -196
  78. triton/backends/amd/include/hip/amd_detail/amd_hip_unsafe_atomics.h +0 -565
  79. triton/backends/amd/include/hip/amd_detail/amd_hip_vector_types.h +0 -2226
  80. triton/backends/amd/include/hip/amd_detail/amd_math_functions.h +0 -104
  81. triton/backends/amd/include/hip/amd_detail/amd_surface_functions.h +0 -244
  82. triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +0 -538
  83. triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +0 -288
  84. triton/backends/amd/include/hip/amd_detail/concepts.hpp +0 -30
  85. triton/backends/amd/include/hip/amd_detail/device_library_decls.h +0 -133
  86. triton/backends/amd/include/hip/amd_detail/functional_grid_launch.hpp +0 -218
  87. triton/backends/amd/include/hip/amd_detail/grid_launch.h +0 -67
  88. triton/backends/amd/include/hip/amd_detail/grid_launch.hpp +0 -50
  89. triton/backends/amd/include/hip/amd_detail/grid_launch_GGL.hpp +0 -26
  90. triton/backends/amd/include/hip/amd_detail/helpers.hpp +0 -137
  91. triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +0 -1446
  92. triton/backends/amd/include/hip/amd_detail/hip_assert.h +0 -101
  93. triton/backends/amd/include/hip/amd_detail/hip_cooperative_groups_helper.h +0 -242
  94. triton/backends/amd/include/hip/amd_detail/hip_fp16_gcc.h +0 -254
  95. triton/backends/amd/include/hip/amd_detail/hip_fp16_math_fwd.h +0 -96
  96. triton/backends/amd/include/hip/amd_detail/hip_ldg.h +0 -100
  97. triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +0 -10570
  98. triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +0 -78
  99. triton/backends/amd/include/hip/amd_detail/host_defines.h +0 -184
  100. triton/backends/amd/include/hip/amd_detail/hsa_helpers.hpp +0 -102
  101. triton/backends/amd/include/hip/amd_detail/macro_based_grid_launch.hpp +0 -798
  102. triton/backends/amd/include/hip/amd_detail/math_fwd.h +0 -698
  103. triton/backends/amd/include/hip/amd_detail/ockl_image.h +0 -177
  104. triton/backends/amd/include/hip/amd_detail/program_state.hpp +0 -107
  105. triton/backends/amd/include/hip/amd_detail/texture_fetch_functions.h +0 -491
  106. triton/backends/amd/include/hip/amd_detail/texture_indirect_functions.h +0 -478
  107. triton/backends/amd/include/hip/channel_descriptor.h +0 -39
  108. triton/backends/amd/include/hip/device_functions.h +0 -38
  109. triton/backends/amd/include/hip/driver_types.h +0 -468
  110. triton/backends/amd/include/hip/hip_bf16.h +0 -36
  111. triton/backends/amd/include/hip/hip_bfloat16.h +0 -44
  112. triton/backends/amd/include/hip/hip_common.h +0 -100
  113. triton/backends/amd/include/hip/hip_complex.h +0 -38
  114. triton/backends/amd/include/hip/hip_cooperative_groups.h +0 -46
  115. triton/backends/amd/include/hip/hip_deprecated.h +0 -95
  116. triton/backends/amd/include/hip/hip_ext.h +0 -161
  117. triton/backends/amd/include/hip/hip_fp16.h +0 -36
  118. triton/backends/amd/include/hip/hip_fp8.h +0 -33
  119. triton/backends/amd/include/hip/hip_gl_interop.h +0 -32
  120. triton/backends/amd/include/hip/hip_hcc.h +0 -24
  121. triton/backends/amd/include/hip/hip_math_constants.h +0 -36
  122. triton/backends/amd/include/hip/hip_profile.h +0 -27
  123. triton/backends/amd/include/hip/hip_runtime.h +0 -75
  124. triton/backends/amd/include/hip/hip_runtime_api.h +0 -9261
  125. triton/backends/amd/include/hip/hip_texture_types.h +0 -29
  126. triton/backends/amd/include/hip/hip_vector_types.h +0 -41
  127. triton/backends/amd/include/hip/hip_version.h +0 -17
  128. triton/backends/amd/include/hip/hiprtc.h +0 -421
  129. triton/backends/amd/include/hip/library_types.h +0 -78
  130. triton/backends/amd/include/hip/math_functions.h +0 -42
  131. triton/backends/amd/include/hip/surface_types.h +0 -63
  132. triton/backends/amd/include/hip/texture_types.h +0 -194
  133. triton/backends/amd/include/hsa/Brig.h +0 -1131
  134. triton/backends/amd/include/hsa/amd_hsa_common.h +0 -91
  135. triton/backends/amd/include/hsa/amd_hsa_elf.h +0 -462
  136. triton/backends/amd/include/hsa/amd_hsa_kernel_code.h +0 -269
  137. triton/backends/amd/include/hsa/amd_hsa_queue.h +0 -109
  138. triton/backends/amd/include/hsa/amd_hsa_signal.h +0 -80
  139. triton/backends/amd/include/hsa/hsa.h +0 -5738
  140. triton/backends/amd/include/hsa/hsa_amd_tool.h +0 -91
  141. triton/backends/amd/include/hsa/hsa_api_trace.h +0 -579
  142. triton/backends/amd/include/hsa/hsa_api_trace_version.h +0 -68
  143. triton/backends/amd/include/hsa/hsa_ext_amd.h +0 -3146
  144. triton/backends/amd/include/hsa/hsa_ext_finalize.h +0 -531
  145. triton/backends/amd/include/hsa/hsa_ext_image.h +0 -1454
  146. triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +0 -488
  147. triton/backends/amd/include/hsa/hsa_ven_amd_loader.h +0 -667
  148. triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +0 -416
  149. triton/backends/amd/include/roctracer/ext/prof_protocol.h +0 -107
  150. triton/backends/amd/include/roctracer/hip_ostream_ops.h +0 -4515
  151. triton/backends/amd/include/roctracer/hsa_ostream_ops.h +0 -1727
  152. triton/backends/amd/include/roctracer/hsa_prof_str.h +0 -3059
  153. triton/backends/amd/include/roctracer/roctracer.h +0 -779
  154. triton/backends/amd/include/roctracer/roctracer_ext.h +0 -81
  155. triton/backends/amd/include/roctracer/roctracer_hcc.h +0 -24
  156. triton/backends/amd/include/roctracer/roctracer_hip.h +0 -37
  157. triton/backends/amd/include/roctracer/roctracer_hsa.h +0 -112
  158. triton/backends/amd/include/roctracer/roctracer_plugin.h +0 -137
  159. triton/backends/amd/include/roctracer/roctracer_roctx.h +0 -67
  160. triton/backends/amd/include/roctracer/roctx.h +0 -229
  161. triton/language/_utils.py +0 -21
  162. triton/language/extra/cuda/_experimental_tma.py +0 -106
  163. triton/tools/experimental_descriptor.py +0 -32
  164. triton_windows-3.3.1.post19.dist-info/RECORD +0 -260
  165. triton_windows-3.3.1.post19.dist-info/top_level.txt +0 -14
  166. {triton_windows-3.3.1.post19.dist-info → triton_windows-3.4.0.post20.dist-info}/WHEEL +0 -0

triton/language/standard.py

@@ -9,7 +9,7 @@ from . import math
 
 def _log2(i: core.constexpr):
     log2 = 0
-    n = i.value
+    n = core.constexpr(i).value
     while n > 1:
         n >>= 1
         log2 += 1
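
The only functional change in _log2 is defensive: wrapping the argument in core.constexpr lets the helper accept either a plain Python int or an already-wrapped constexpr, since the constructor simply unwraps an existing constexpr. A tiny illustration of that behaviour (not part of the diff):

from triton.language import core

assert core.constexpr(8).value == 8                  # plain int
assert core.constexpr(core.constexpr(8)).value == 8  # already-wrapped value
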
@@ -50,10 +50,14 @@ def sigmoid(x):
 @core._tensor_member_fn
 @jit
 @math._add_math_1arg_docstr("softmax")
-def softmax(x, ieee_rounding=False):
-    z = x - max(x, 0)
+def softmax(x, dim=None, keep_dims=False, ieee_rounding=False):
+    if dim is None:
+        _dim: core.constexpr = 0
+    else:
+        _dim: core.constexpr = dim
+    z = x - max(x, _dim, keep_dims=keep_dims)
     num = math.exp(z)
-    den = sum(num, 0)
+    den = sum(num, _dim, keep_dims=keep_dims)
     return math.fdiv(num, den, ieee_rounding)
 
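softmax now takes an explicit reduction dimension and a keep_dims flag, mirroring max and sum. A minimal sketch of a row-wise kernel written against the new signature (the kernel, pointer names and block size are illustrative, not from the package):

import triton
import triton.language as tl


@triton.jit
def softmax_rows(x_ptr, y_ptr, n_cols, BLOCK: tl.constexpr):
    # one program per row; dim=0 reduces over the columns loaded below
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    x = tl.load(x_ptr + row * n_cols + cols, mask=mask, other=-float("inf"))
    y = tl.softmax(x, dim=0)
    tl.store(y_ptr + row * n_cols + cols, y, mask=mask)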
 
@@ -302,15 +306,37 @@ def xor_sum(input, axis=None, keep_dims=False):
     return core.reduce(input, axis, _xor_combine, keep_dims=keep_dims)
 
 
+# or reduction
+
+
+@jit
+def _or_combine(x, y):
+    return x | y
+
+
+@core._tensor_member_fn
+@jit
+@core._add_reduction_docstr("reduce_of")
+def reduce_or(input, axis, keep_dims=False):
+    core.static_assert(input.type.scalar.is_int(), "reduce_of only supported for integers")
+    return core.reduce(input, axis, _or_combine, keep_dims=keep_dims)
+
+
 # cumsum
 
 
 @core._tensor_member_fn
 @jit
-@core._add_scan_docstr("cumsum")
-def cumsum(input, axis=0, reverse=False):
+@core._add_scan_docstr("cumsum", dtype_arg="dtype")
+def cumsum(input, axis=0, reverse=False, dtype: core.constexpr = None):
     # todo rename this to a generic function name
+
     input = core._promote_bfloat16_to_float32(input)
+    out_dtype: core.constexpr = _pick_sum_dtype(input.dtype, dtype)
+
+    if out_dtype is not None:
+        input = input.to(out_dtype)
+
     return core.associative_scan(input, axis, _sum_combine, reverse)
 
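Two additions land here: an integer OR reduction (reduce_or) and a dtype argument for cumsum so the running sum can be accumulated in a wider type. A hedged sketch, assuming both are re-exported under triton.language the way the file's other helpers are (pointer names are illustrative):

import triton
import triton.language as tl


@triton.jit
def flags_kernel(flags_ptr, any_ptr, prefix_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    flags = tl.load(flags_ptr + offs)              # int32 bit masks
    # collapse every flag word into one mask (integers only, per the static_assert)
    combined = tl.reduce_or(flags, axis=0)
    tl.store(any_ptr, combined)
    # accumulate the prefix sum in int64 to avoid overflowing int32
    prefix = tl.cumsum(flags, axis=0, dtype=tl.int64)
    tl.store(prefix_ptr + offs, prefix)            # prefix_ptr assumed to point at int64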
 
@@ -335,53 +361,63 @@ def cumprod(input, axis=0, reverse=False):
 
 
 @jit
-def _compare_and_swap(x, flip, i: core.constexpr, n_dims: core.constexpr):
-    n_outer: core.constexpr = x.numel >> n_dims
-    shape: core.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
-    y = core.reshape(x, shape)
-    # slice left/right with 'stride' 2**(n_dims - i - 1)
-    mask = core.arange(0, 2)[None, :, None]
-    left = core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype)
-    right = core.broadcast_to(sum(y * mask, 1)[:, None, :], shape).to(y.dtype)
-    left = core.reshape(left, x.shape)
-    right = core.reshape(right, x.shape)
-    # actual compare-and-swap
+def _indicator(n_dims: core.constexpr, j: core.constexpr):
+    ar = core.arange(0, 2)
+    ar = core.reshape(ar, [1] * (n_dims - j - 1) + [2] + [1] * j)
+    return ar
+
+
+@jit
+def _compare_and_swap(x, flip, i: core.constexpr):
+    # compare-and-swap on the ith *innermost* dimension
+    n_dims: core.constexpr = _log2(x.numel)
+
+    # flip along middle dimension (the bitwise XORs will be optimised away):
     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
-    ileft = left.to(idtype, bitcast=True)
-    iright = right.to(idtype, bitcast=True)
     ix = x.to(idtype, bitcast=True)
-    ret = ix ^ core.where((left > right) != flip, ileft ^ iright, zeros_like(ix))
-    return ret.to(x.dtype, bitcast=True)
+    iy = ix ^ xor_sum(ix, n_dims - 1 - i, True)
+    y = iy.to(x.dtype, bitcast=True)
+
+    # determines whether we are in the right (rather than left) position along the axis:
+    is_right = _indicator(n_dims, i)
+
+    # conditional swap:
+    ret = core.where((x > y) != (flip ^ is_right), y, x)
+    return ret
 
 
 @jit
-def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):
+def _bitonic_merge_hypercube(x, stage: core.constexpr, order: core.constexpr):
     '''
     order_type 0 == ascending
     order_type 1 == descending
     order_type 2 == alternating
     '''
-    n_outer: core.constexpr = x.numel >> n_dims
-    core.static_assert(stage <= n_dims)
     # flip denotes whether to re-arrange sub-sequences of elements in ascending or
     # descending order.
     # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
     # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
     # a stride of 2) at this stage
     if order == 2:
-        shape: core.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]
-        flip = core.reshape(core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape)
+        flip = _indicator(_log2(x.numel), stage)
     else:
         flip = order
     # perform `stage` rounds of `compare-and-swap`
     for i in core.static_range(stage):
-        x = _compare_and_swap(x, flip, i + (n_dims - stage), n_dims)
+        x = _compare_and_swap(x, flip, stage - 1 - i)
     return x
 
 
-@core._tensor_member_fn
 @jit
-def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+def _bitonic_merge(x, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr):
+    h = core.reshape(x, [2] * _log2(x.numel))
+    h = _bitonic_merge_hypercube(h, stage, order)
+    x = core.reshape(h, x.shape)
+    return x
+
+
+@jit
+def sort_impl(x, k: core.constexpr = None, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
     """
     Sorts a tensor along a specified dimension.
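The rewritten helpers replace the old reshape/broadcast dance with an XOR exchange: once a tensor is viewed as a [2, 2, ..., 2] hypercube, xor_sum along one axis with keep_dims=True yields a ^ b for every pair on that axis, and XOR-ing that back into the tensor hands every element its partner. A small NumPy illustration of the identity (illustrative only, not Triton code):

import numpy as np

# eight distinct values laid out on a 2x2x2 hypercube
x = np.arange(8, dtype=np.int64).reshape(2, 2, 2)

# XOR-reduce over axis 0 gives a ^ b for each pair; a ^ (a ^ b) == b,
# so every element is replaced by its partner along that axis
partner = x ^ np.bitwise_xor.reduce(x, axis=0, keepdims=True)
assert np.array_equal(partner, x[::-1, :, :])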
 
@@ -389,20 +425,55 @@ def sort(x, dim: core.constexpr = None, descending: core.CONSTE
     :type x: Tensor
     :param dim: The dimension along which to sort the tensor. If None, the tensor is sorted along the last dimension. Currently, only sorting along the last dimension is supported.
     :type dim: int, optional
+    :param k: the number of top elements to select. If none, assume k = x.shape[dim]
+    :type k: int, optional
     :param descending: If set to True, the tensor is sorted in descending order. If set to False, the tensor is sorted in ascending order.
     :type descending: bool, optional
     """
     # handle default dimension or check that it is the most minor dim
     _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
     core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
-    # iteratively run bitonic merge-sort steps
-    n_dims: core.constexpr = _log2(x.shape[_dim])
-    for i in core.static_range(1, n_dims + 1):
-        x = _bitonic_merge(x, i, 2 if i < n_dims else descending, n_dims)
+
+    log_n: core.constexpr = _log2(x.shape[_dim])
+    log_k: core.constexpr = log_n if k is None else _log2(k)
+
+    n_dims: core.constexpr = _log2(x.numel)
+
+    # reshape to hypercube:
+    h = core.reshape(x, [2] * n_dims)
+
+    # run first log_k bitonic sort iterations:
+    for i in core.static_range(1, log_k + 1):
+        h = _bitonic_merge_hypercube(h, i, 2 if i < log_n else descending)
+
+    # select top k elements using bitonic top-k
+    # https://www.doc.ic.ac.uk/~hlgr/pdfs/MassivelyParallelTopK.pdf
+    for i in core.static_range(log_k + 1, log_n + 1):
+        h = max(h, axis=(_log2(h.numel) - 1 - log_k)) if descending else min(h, axis=(_log2(h.numel) - 1 - log_k))
+        h = _bitonic_merge_hypercube(h, log_k, 2 if i < log_n else descending)
+
+    # reshape back:
+    x = core.reshape(h, x.shape[:-1] + [2**log_k])
     return x
 
 
-# flip
+@jit
+def sort(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    return sort_impl(x, dim=dim, descending=descending)
+
+
+@jit
+def topk(x, k: core.constexpr, dim: core.constexpr = None):
+    return sort_impl(x, k=k, dim=dim, descending=True)
+
+
+@jit
+def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = core.CONSTEXPR_0):
+    # handle default dimension or check that it is the most minor dim
+    _dim: core.constexpr = len(x.shape) - 1 if dim is None else dim
+    core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
+    n_dims: core.constexpr = _log2(x.shape[-1])
+    return _bitonic_merge(x, n_dims, descending, n_dims)
 
 
 def _get_flip_dim(dim, shape):
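
sort is now a thin wrapper over sort_impl, which also powers the new topk (bitonic top-k, always descending) and bitonic_merge entry points. A sketch of per-row top-k selection, assuming topk is exported as tl.topk like the other functions in this module (kernel and pointer names are illustrative):

import triton
import triton.language as tl


@triton.jit
def row_topk(x_ptr, out_ptr, N: tl.constexpr, K: tl.constexpr):
    # N and K must be powers of two: sort_impl reshapes the row to a hypercube
    row = tl.program_id(0)
    x = tl.load(x_ptr + row * N + tl.arange(0, N))
    top = tl.topk(x, K)                     # the K largest values, descending
    tl.store(out_ptr + row * K + tl.arange(0, K), top)
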
@@ -410,7 +481,8 @@ def _get_flip_dim(dim, shape):
     shape = core._unwrap_if_constexpr(shape)
     if dim is None:
         dim = len(shape) - 1
-    assert dim == len(shape) - 1, "Currently only support flipping the last dimension"
+    if dim < 0:  # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
+        dim += len(shape)
     return core.constexpr(dim)
 
 
@@ -422,26 +494,19 @@ def flip(x, dim=None):
 
     :param x: the first input tensor
     :type x: Block
-    :param dim: the dimension to flip along (currently only final dimension supported)
+    :param dim: the dimension to flip along
     :type dim: int
     """
-    core.static_assert(_is_power_of_two(x.shape[_get_flip_dim(dim, x.shape)]))
-    core.static_assert(_is_power_of_two(x.numel))
-    # reshape the tensor to have all dimensions be 2.
-    # TODO: We shouldn't have to change the dimensions not sorted.
-    steps: core.constexpr = _log2(x.numel)
-    start: core.constexpr = _log2(x.numel) - _log2(x.shape[_get_flip_dim(dim, x.shape)])
+    core.static_assert(-len(x.shape) <= dim and dim < len(x.shape))
+    _dim: core.constexpr = _get_flip_dim(dim, x.shape)
+    core.static_assert(_is_power_of_two(x.shape[_dim]))
+    steps: core.constexpr = _log2(x.shape[_dim])
 
+    # reshape the swap dimension to (2, 2, ..., 2)
     idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
-    y = core.reshape(x.to(idtype, bitcast=True), [2] * steps)
-    y = core.expand_dims(y, start)
-    flip = (core.arange(0, 2)[:, None] == 1 - core.arange(0, 2))
-    for i in core.static_range(start, steps):
-        flip2 = flip
-        for j in core.static_range(0, steps + 1):
-            if j != i and j != i + 1:
-                flip2 = core.expand_dims(flip2, j)
-        y = sum(y * flip2, i + 1, keep_dims=True, dtype=y.dtype)
+    y = core.reshape(x.to(idtype, bitcast=True), x.shape[:_dim] + [2] * steps + x.shape[_dim + 1:])
+    for i in core.static_range(steps):
+        y = y ^ xor_sum(y, _dim + i, True)
     x = core.reshape(y, x.shape).to(x.dtype, bitcast=True)
     return x
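flip is no longer limited to the last dimension: _get_flip_dim normalises negative dims and the xor_sum-based swap above works on any power-of-two axis. A sketch flipping the row axis of a 2-D block (names illustrative; M and N must be powers of two):

import triton
import triton.language as tl


@triton.jit
def flip_rows(x_ptr, y_ptr, M: tl.constexpr, N: tl.constexpr):
    offs_m = tl.arange(0, M)[:, None]
    offs_n = tl.arange(0, N)[None, :]
    x = tl.load(x_ptr + offs_m * N + offs_n)
    y = tl.flip(x, dim=0)                   # previously only the final dimension was supported
    tl.store(y_ptr + offs_m * N + offs_n, y)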
 

triton/runtime/autotuner.py

@@ -1,11 +1,14 @@
 from __future__ import annotations
 
 import builtins
-import os
 import time
 import inspect
+import hashlib
+import json
+from functools import cached_property
 from typing import Dict, Tuple, List, Optional
 
+from .. import knobs
 from .jit import KernelInterface
 from .errors import OutOfResources, PTXASError
 from .driver import driver
@@ -13,22 +16,9 @@ from .driver import driver
 
 class Autotuner(KernelInterface):
 
-    def __init__(
-        self,
-        fn,
-        arg_names,
-        configs,
-        key,
-        reset_to_zero,
-        restore_value,
-        pre_hook=None,
-        post_hook=None,
-        prune_configs_by: Optional[Dict] = None,
-        warmup=None,
-        rep=None,
-        use_cuda_graph=False,
-        do_bench=None,
-    ):
+    def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pre_hook=None, post_hook=None,
+                 prune_configs_by: Optional[Dict] = None, warmup=None, rep=None, use_cuda_graph=False, do_bench=None,
+                 cache_results=False):
         """
         :param prune_configs_by: a dict of functions that are used to prune configs, fields:
             'perf_model': performance model used to predicate running time with different configs, returns running time
@@ -36,15 +26,13 @@ class Autotuner(KernelInterface):
            'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
         """
         if not configs:
-            self.configs = [
-                Config({}, num_warps=4, num_stages=3, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0,
-                       reg_dec_producer=0, reg_inc_consumer=0)
-            ]
+            self.configs = [Config({}, num_warps=4, num_stages=3, num_ctas=1)]
         else:
             self.configs = configs
         self.keys = key
         self.cache: Dict[Tuple, Config] = {}
         self.arg_names = arg_names
+        self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret)
 
         # Reset to zero or restore values
         self.reset_to_zero = []
@@ -97,6 +85,7 @@ class Autotuner(KernelInterface):
         while not inspect.isfunction(self.base_fn):
             self.base_fn = self.base_fn.fn
 
+        self._do_bench = do_bench
         self.num_warmups = warmup
         self.num_reps = rep
         self.use_cuda_graph = use_cuda_graph
@@ -110,7 +99,7 @@
                           stacklevel=1)
             if use_cuda_graph:
                 from ..testing import do_bench_cudagraph
-                self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
+                self._do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
                     kernel_call,
                     rep=rep if rep is not None else 100,
                     quantiles=quantiles,
@@ -118,7 +107,7 @@
                 return
 
             import triton.testing
-            self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
+            self._do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
                 kernel_call,
                 warmup=warmup if warmup is not None else 25,
                 rep=rep if rep is not None else 100,
@@ -126,15 +115,16 @@
             )
             return
 
-        if do_bench is None:
-            self.do_bench = driver.active.get_benchmarker()
-        else:
-            self.do_bench = do_bench
+    @cached_property
+    def do_bench(self):
+        if self._do_bench is None:
+            return driver.active.get_benchmarker()
+        return self._do_bench
 
     def _bench(self, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure
 
-        verbose = os.environ.get("TRITON_PRINT_AUTOTUNING", None) == "1"
+        verbose = knobs.autotuning.print
         if verbose:
             print(f"Autotuning kernel {self.base_fn.__name__} with config {config}")
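With do_bench stored as _do_bench and resolved lazily through the cached_property, the active driver's benchmarker is only looked up when it is actually needed. A caller-supplied benchmarker keeps the same (kernel_call, quantiles) shape the class builds internally; a hedged sketch (function name and timing knobs are illustrative):

import triton.testing


def my_bench(kernel_call, quantiles):
    # same call shape as the lambdas the Autotuner builds for the deprecated
    # warmup/rep/use_cuda_graph path above
    return triton.testing.do_bench(kernel_call, warmup=10, rep=50, quantiles=quantiles)

It would then be passed as @triton.autotune(..., do_bench=my_bench).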
 
@@ -173,6 +163,51 @@
                 print(f"Autotuning failed with {e}")
             return [float("inf"), float("inf"), float("inf")]
 
+    def check_disk_cache(self, tuning_key, configs, bench_fn):
+        # We can't serialize prehooks, so just give up and run the benchmarks.
+        if not tuning_key or any(cfg.pre_hook for cfg in configs):
+            bench_fn()
+            return False
+
+        from triton._C.libtriton import get_cache_invalidating_env_vars
+        from triton.compiler.compiler import make_backend, triton_key
+        from triton.runtime.cache import get_cache_manager
+        from triton.runtime.jit import JITFunction
+
+        fn = self.fn
+        while not isinstance(fn, JITFunction):
+            fn = fn.fn
+
+        env_vars = get_cache_invalidating_env_vars()
+        cache_key = [
+            triton_key(),
+            make_backend(driver.active.get_current_target()).hash(),
+            fn.cache_key,
+            str(sorted(env_vars.items())),
+            str(tuning_key),
+        ] + [str(c) for c in configs]
+        cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest()
+        cache = get_cache_manager(cache_key)
+        file_name = f"{fn.__name__[:150]}.autotune.json"
+        path = cache.get_file(file_name)
+        if path:
+            with open(path, "r") as cached_configs:
+                timings = json.load(cached_configs)["configs_timings"]
+                timings = {Config(**config): timing for config, timing in timings}
+                self.cache[tuning_key] = builtins.min(timings, key=timings.get)
+                self.configs_timings = timings
+            return True
+
+        bench_fn()
+        cache.put(
+            json.dumps({
+                "key":
+                tuning_key,
+                "configs_timings":
+                [(config.__dict__, timings) for config, timings in self.configs_timings.items() if not config.pre_hook],
+            }), file_name, binary=False)
+        return False
+
     def run(self, *args, **kwargs):
         self.nargs = dict(zip(self.arg_names, args))
         used_cached_result = True
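
check_disk_cache keys the entry on the Triton version, the backend hash, the kernel's cache key, the cache-invalidating environment variables, the tuning key and the candidate configs, then stores a small JSON document through the regular cache manager. Roughly what one payload looks like, with invented values for illustration:

# written to f"{fn.__name__[:150]}.autotune.json" by cache.put above
example_payload = {
    "key": ["1024", "torch.float16"],
    "configs_timings": [
        # (Config.__dict__, timings) pairs; configs with a pre_hook are skipped
        ({"kwargs": {"BLOCK": 128}, "num_warps": 4, "num_stages": 3, "num_ctas": 1,
          "maxnreg": None, "pre_hook": None, "ir_override": None},
         [0.021, 0.019, 0.024]),
    ],
}
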
@@ -185,24 +220,31 @@
                 key.append(str(arg.dtype))
         key = tuple(key)
         if key not in self.cache:
-            # prune configs
             used_cached_result = False
             pruned_configs = self.prune_configs(kwargs)
-            bench_start = time.time()
-            timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
-            bench_end = time.time()
-            self.bench_time = bench_end - bench_start
-            self.cache[key] = builtins.min(timings, key=timings.get)
-            full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
-            self.pre_hook(full_nargs, reset_only=True)
-            self.configs_timings = timings
+
+            def benchmark():
+                bench_start = time.perf_counter()
+                timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
+                bench_end = time.perf_counter()
+                self.bench_time = bench_end - bench_start
+                self.cache[key] = builtins.min(timings, key=timings.get)
+                full_nargs = {**self.nargs, **kwargs, **self.cache[key].all_kwargs()}
+                self.pre_hook(full_nargs, reset_only=True)
+                self.configs_timings = timings
+
+            if self.cache_results:
+                used_cached_result = self.check_disk_cache(key, pruned_configs, benchmark)
+            else:
+                benchmark()
+
             config = self.cache[key]
         else:
             config = self.configs[0]
         self.best_config = config
-        if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
-            print(f"Triton autotuning for function {self.base_fn.__name__} finished after "
-                  f"{self.bench_time:.2f}s; best config selected: {self.best_config};")
+        if knobs.autotuning.print and not used_cached_result:
+            print(f"Triton autotuning for function {self.base_fn.__name__},\nwith key as {key},\n"
+                  f"finished after {self.bench_time:.2f}s,\nbest config selected: {self.best_config};")
         if config.pre_hook is not None:
             full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
             config.pre_hook(full_nargs)
@@ -241,11 +283,11 @@
     def warmup(self, *args, **kwargs):
         self.nargs = dict(zip(self.arg_names, args))
         ret = []
-        for config in self.prune_configs(kwargs):
+        for autotune_config in self.prune_configs(kwargs):
             ret.append(self.fn.warmup(
                 *args,
                 **kwargs,
-                **config.all_kwargs(),
+                **autotune_config.all_kwargs(),
             ))
         self.nargs = None
         return ret
@@ -263,27 +305,34 @@ class Config:
     :type num_warps: int
     :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops.
                       Mostly useful for matrix multiplication workloads on SM80+ GPUs.
-    :type num_ctas: int
+    :type num_stages: int
     :ivar num_ctas: number of blocks in a block cluster. SM90+ only.
+    :type num_ctas: int
     :type maxnreg: Optional[int]
     :ivar maxnreg: maximum number of registers one thread can use. Corresponds
                    to ptx .maxnreg directive. Not supported on all platforms.
     :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this
                     function are args.
+    :ivar ir_override: filename of a user-defined IR (*.{ttgir|llir|ptx|amdgcn}).
     """
 
-    def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0,
-                 reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None):
+    def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, maxnreg=None, pre_hook=None, ir_override=None):
         self.kwargs = kwargs
         self.num_warps = num_warps
         self.num_ctas = num_ctas
         self.num_stages = num_stages
-        self.num_buffers_warp_spec = num_buffers_warp_spec
-        self.num_consumer_groups = num_consumer_groups
-        self.reg_dec_producer = reg_dec_producer
-        self.reg_inc_consumer = reg_inc_consumer
         self.maxnreg = maxnreg
         self.pre_hook = pre_hook
+        self.ir_override = ir_override
+
+    def __setstate__(self, state):
+        self.kwargs = state.get("kwargs", {})
+        self.num_warps = state.get("num_warps", 4)
+        self.num_stages = state.get("num_stages", 3)
+        self.num_ctas = state.get("num_ctas", 1)
+        self.maxnreg = state.get("maxnreg", None)
+        self.pre_hook = state.get("pre_hook", None)
+        self.ir_override = state.get("ir_override", None)
 
     def all_kwargs(self):
         return {
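
Config gains an ir_override field plus a __setstate__ that supplies defaults for any missing attribute, so configs deserialized from older cache files (or pickled) keep working; together with the __hash__/__eq__ added further down, configs can serve as dictionary keys for the cached timings. A sketch (the IR filename is invented):

import pickle
import triton

cfg = triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=8, num_stages=4,
                    ir_override="matmul_tuned.ttgir")   # hypothetical hand-edited IR

# __setstate__ fills in defaults for absent fields, so the round trip is safe
restored = pickle.loads(pickle.dumps(cfg))
assert restored == cfg and hash(restored) == hash(cfg)
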
@@ -293,11 +342,8 @@
                 ("num_warps", self.num_warps),
                 ("num_ctas", self.num_ctas),
                 ("num_stages", self.num_stages),
-                ("num_buffers_warp_spec", self.num_buffers_warp_spec),
-                ("num_consumer_groups", self.num_consumer_groups),
-                ("reg_dec_producer", self.reg_dec_producer),
-                ("reg_inc_consumer", self.reg_inc_consumer),
                 ("maxnreg", self.maxnreg),
+                ("ir_override", self.ir_override),
             ) if v is not None
         }
     }
@@ -309,16 +355,26 @@
         res.append(f"num_warps: {self.num_warps}")
         res.append(f"num_ctas: {self.num_ctas}")
         res.append(f"num_stages: {self.num_stages}")
-        res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}")
-        res.append(f"num_consumer_groups: {self.num_consumer_groups}")
-        res.append(f"reg_dec_producer: {self.reg_dec_producer}")
-        res.append(f"reg_inc_consumer: {self.reg_inc_consumer}")
         res.append(f"maxnreg: {self.maxnreg}")
         return ", ".join(res)
 
+    def __hash__(self):
+        return hash((*self.all_kwargs().items(), self.pre_hook))
+
+    def __eq__(self, other):
+        self_tuple = tuple((
+            *self.all_kwargs().items(),
+            self.pre_hook,
+        ))
+        other_tuple = tuple((
+            *other.all_kwargs().items(),
+            other.pre_hook,
+        ))
+        return self_tuple == other_tuple
+
 
 def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_value=None, pre_hook=None, post_hook=None,
-             warmup=None, rep=None, use_cuda_graph=False, do_bench=None):
+             warmup=None, rep=None, use_cuda_graph=False, do_bench=None, cache_results=False):
     """
     Decorator for auto-tuning a :code:`triton.jit`'d function.
 
@@ -372,12 +428,14 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, restore_va
     :type rep: int
     :param do_bench: a benchmark function to measure the time of each run.
     :type do_bench: lambda fn, quantiles
+    :param cache_results: whether to cache autotune timings to disk. Defaults to False.
+    "type cache_results: bool
     """
 
     def decorator(fn):
         return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
                          post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
-                         use_cuda_graph=use_cuda_graph, do_bench=do_bench)
+                         use_cuda_graph=use_cuda_graph, do_bench=do_bench, cache_results=cache_results)
 
     return decorator
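
Putting the pieces together, cache_results=True (or the corresponding knobs.autotuning.cache knob) makes the Autotuner persist its timings through check_disk_cache, so later processes hitting the same tuning key skip re-benchmarking. A sketch of the decorator usage (kernel, configs and key names are illustrative):

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 256}, num_warps=4),
        triton.Config({"BLOCK": 1024}, num_warps=8),
    ],
    key=["n_elements"],
    cache_results=True,   # persist timings to the Triton cache directory
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)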