triton-windows 3.4.0.post20__cp312-cp312-win_amd64.whl → 3.5.0.post21__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of triton-windows might be problematic.

Files changed (107)
  1. triton/_C/libtriton.pyd +0 -0
  2. triton/__init__.py +8 -2
  3. triton/_filecheck.py +24 -14
  4. triton/_internal_testing.py +70 -4
  5. triton/_utils.py +3 -1
  6. triton/backends/amd/compiler.py +68 -60
  7. triton/backends/amd/driver.c +113 -44
  8. triton/backends/amd/driver.py +133 -57
  9. triton/backends/driver.py +13 -0
  10. triton/backends/nvidia/compiler.py +80 -22
  11. triton/backends/nvidia/driver.c +88 -15
  12. triton/backends/nvidia/driver.py +130 -123
  13. triton/compiler/__init__.py +5 -2
  14. triton/compiler/code_generator.py +270 -163
  15. triton/compiler/compiler.py +45 -62
  16. triton/experimental/gluon/__init__.py +3 -2
  17. triton/experimental/gluon/_runtime.py +9 -6
  18. triton/experimental/gluon/language/__init__.py +117 -16
  19. triton/experimental/gluon/language/_core.py +246 -68
  20. triton/experimental/gluon/language/_layouts.py +398 -45
  21. triton/experimental/gluon/language/_math.py +17 -9
  22. triton/experimental/gluon/language/_semantic.py +130 -37
  23. triton/experimental/gluon/language/_standard.py +55 -22
  24. triton/experimental/gluon/language/amd/__init__.py +4 -0
  25. triton/experimental/gluon/language/amd/_layouts.py +96 -0
  26. triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
  27. triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
  28. triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
  29. triton/experimental/gluon/language/extra/__init__.py +3 -0
  30. triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
  31. triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
  32. triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
  33. triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
  34. triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
  35. triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
  36. triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
  37. triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
  38. triton/experimental/gluon/nvidia/hopper.py +6 -1
  39. triton/knobs.py +132 -67
  40. triton/language/__init__.py +16 -10
  41. triton/language/core.py +163 -83
  42. triton/language/extra/cuda/gdc.py +6 -6
  43. triton/language/extra/hip/__init__.py +3 -1
  44. triton/language/extra/hip/libdevice.py +7 -0
  45. triton/language/extra/hip/utils.py +35 -0
  46. triton/language/extra/libdevice.py +4 -0
  47. triton/language/semantic.py +76 -23
  48. triton/language/standard.py +14 -14
  49. triton/language/target_info.py +54 -0
  50. triton/runtime/_allocation.py +15 -3
  51. triton/runtime/_async_compile.py +55 -0
  52. triton/runtime/autotuner.py +4 -5
  53. triton/runtime/build.py +11 -9
  54. triton/runtime/cache.py +44 -1
  55. triton/runtime/driver.py +16 -41
  56. triton/runtime/interpreter.py +31 -23
  57. triton/runtime/jit.py +318 -157
  58. triton/runtime/tcc/include/_mingw.h +8 -10
  59. triton/runtime/tcc/include/assert.h +5 -0
  60. triton/runtime/tcc/include/errno.h +1 -1
  61. triton/runtime/tcc/include/float.h +21 -3
  62. triton/runtime/tcc/include/iso646.h +36 -0
  63. triton/runtime/tcc/include/limits.h +5 -0
  64. triton/runtime/tcc/include/malloc.h +2 -2
  65. triton/runtime/tcc/include/math.h +21 -261
  66. triton/runtime/tcc/include/stdalign.h +16 -0
  67. triton/runtime/tcc/include/stdarg.h +5 -70
  68. triton/runtime/tcc/include/stdatomic.h +171 -0
  69. triton/runtime/tcc/include/stddef.h +7 -19
  70. triton/runtime/tcc/include/stdlib.h +15 -4
  71. triton/runtime/tcc/include/stdnoreturn.h +7 -0
  72. triton/runtime/tcc/include/sys/stat.h +2 -2
  73. triton/runtime/tcc/include/sys/types.h +5 -0
  74. triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
  75. triton/runtime/tcc/include/tccdefs.h +342 -0
  76. triton/runtime/tcc/include/tgmath.h +89 -0
  77. triton/runtime/tcc/include/uchar.h +33 -0
  78. triton/runtime/tcc/include/unistd.h +1 -0
  79. triton/runtime/tcc/include/winapi/qos.h +72 -0
  80. triton/runtime/tcc/include/winapi/shellapi.h +59 -0
  81. triton/runtime/tcc/include/winapi/winbase.h +9 -2
  82. triton/runtime/tcc/include/winapi/wincon.h +8 -0
  83. triton/runtime/tcc/include/winapi/windows.h +1 -1
  84. triton/runtime/tcc/include/winapi/winnls.h +778 -0
  85. triton/runtime/tcc/include/winapi/winnt.h +9 -7
  86. triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
  87. triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
  88. triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
  89. triton/runtime/tcc/lib/libtcc1.a +0 -0
  90. triton/runtime/tcc/lib/python314.def +1800 -0
  91. triton/runtime/tcc/lib/python314t.def +1809 -0
  92. triton/runtime/tcc/libtcc.dll +0 -0
  93. triton/runtime/tcc/tcc.exe +0 -0
  94. triton/tools/compile.py +62 -14
  95. triton/tools/extra/cuda/compile.c +1 -0
  96. triton/tools/extra/hip/compile.cpp +66 -0
  97. triton/tools/extra/hip/compile.h +13 -0
  98. triton/tools/ragged_tma.py +92 -0
  99. triton/tools/tensor_descriptor.py +7 -9
  100. triton/windows_utils.py +42 -79
  101. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
  102. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
  103. triton/runtime/tcc/lib/libtcc1-64.a +0 -0
  104. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
  105. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
  106. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
  107. {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
@@ -466,6 +466,10 @@ def fast_expf(arg0):
     ...
 
 
+def fast_tanhf(arg0):
+    ...
+
+
 def fast_tanf(arg0):
     ...
 
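Note: the hunk above adds a fast_tanhf stub next to the existing fast-math wrappers. A minimal usage sketch, assuming an AMD/HIP build where triton.language.extra.hip.libdevice resolves to these wrappers (the kernel and its parameters are illustrative, not taken from the package):

    import triton
    import triton.language as tl
    from triton.language.extra.hip import libdevice

    @triton.jit
    def tanh_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        # fast_tanhf lowers to the device's fast tanh approximation
        tl.store(y_ptr + offs, libdevice.fast_tanhf(x), mask=mask)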
@@ -219,7 +219,7 @@ class TritonSemantic(Generic[TensorTy]):
         min_value = self.scalar_constant(min_value, tl.int64)
         cond = self.and_(self.less_equal(ret, max_value), self.greater_equal(ret, min_value))
         msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}"
-        self.device_assert(cond, msg)
+        self.device_assert(cond, msg, None)
 
     def add(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
             sanitize_overflow: bool) -> TensorTy:
@@ -619,6 +619,9 @@ class TritonSemantic(Generic[TensorTy]):
         ret_ty = tl.block_type(value.dtype, shape)
         return self.tensor(self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle), ret_ty)
 
+    def unsplat(self, value: TensorTy) -> TensorTy:
+        return self.tensor(self.builder.create_unsplat(value.handle), value.dtype)
+
     def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool) -> TensorTy:
         numel = 1
         for s in dst_shape:
@@ -1034,9 +1037,9 @@ class TritonSemantic(Generic[TensorTy]):
         # Make `mask` and `other` into the same shape as `ptr`
         if ptr.type.is_block():
             if mask is not None:
-                mask = self.broadcast_impl_shape(mask, ptr.type.get_block_shapes())
+                ptr, mask = self.broadcast_impl_value(ptr, mask)
             if other is not None:
-                other = self.broadcast_impl_shape(other, ptr.type.get_block_shapes())
+                ptr, other = self.broadcast_impl_value(ptr, other)
 
         # Get `pointer_type<elt_ty>` and `elt_ty`
         ptr_ty = ptr.type.scalar
@@ -1104,6 +1107,8 @@ class TritonSemantic(Generic[TensorTy]):
 
     def descriptor_store(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
         self.validate_store_like(desc, value, offsets)
+        # implicitly cast to the descriptor's type
+        value = self.cast(value, desc.dtype)
         offsets = self._convert_to_ir_values(offsets, require_i64=False)
         return self.tensor(self.builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void)
 
@@ -1472,10 +1477,10 @@ class TritonSemantic(Generic[TensorTy]):
             # All combinations of supported fp8 x fp8 are permitted
             pass
         else:
-            assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
-                                 tl.float32), f"Unsupported lhs dtype {lhs.dtype}"
-            assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
-                                 tl.float32), f"Unsupported rhs dtype {rhs.dtype}"
+            assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32,
+                                 tl.float64), f"Unsupported lhs dtype {lhs.dtype}"
+            assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32,
+                                 tl.float64), f"Unsupported rhs dtype {rhs.dtype}"
             assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"
 
         if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
@@ -1487,6 +1492,18 @@ class TritonSemantic(Generic[TensorTy]):
             lhs = self.cast(lhs, tl.float16)
             rhs = self.cast(rhs, tl.float16)
 
+        uses_fp8e4b8 = lhs.dtype.is_fp8e4b8() or rhs.dtype.is_fp8e4b8()
+        uses_fp8e5b16 = lhs.dtype.is_fp8e5b16() or rhs.dtype.is_fp8e5b16()
+        if uses_fp8e4b8 or uses_fp8e5b16:
+            type_name = "fp8e4b8" if uses_fp8e4b8 else "fp8e5b16"
+            if type_name in self.builder.options.deprecated_fp8_dot_operand_dtypes:
+                arch = self.builder.options.arch
+                warnings.warn(
+                    f"{type_name} is AMD gfx942 specific and not supported on {arch} so it's upcasted to fp16 and can cause significant slow down. "
+                    f"Please use OCP fp8 variants on {arch} for performance")
+                lhs = self.cast(lhs, tl.float16)
+                rhs = self.cast(rhs, tl.float16)
+
         if input_precision is None:
             input_precision = self.builder.options.default_dot_input_precision
 
@@ -1514,6 +1531,9 @@ class TritonSemantic(Generic[TensorTy]):
         elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
             _0 = self.builder.get_fp32(0)
             ret_scalar_ty = tl.float32
+        elif lhs.type.scalar.is_fp64():
+            _0 = self.builder.get_fp64(0)
+            ret_scalar_ty = tl.float64
         else:
             _0 = self.builder.get_fp16(0) if out_dtype.is_fp16() else self.builder.get_fp32(0)
             ret_scalar_ty = out_dtype
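Note: the hunks above let tl.dot accept float64 operands and give them a float64 accumulator. A hedged sketch of a kernel that the relaxed asserts now admit (tile addressing is illustrative; BLOCK must still satisfy tl.dot's usual minimum-size constraints):

    import triton
    import triton.language as tl

    @triton.jit
    def dot_fp64(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])  # float64 tile
        b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])  # float64 tile
        c = tl.dot(a, b)  # accumulates in float64 via the new is_fp64() branch
        tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)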
@@ -1527,7 +1547,7 @@ class TritonSemantic(Generic[TensorTy]):
             acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0)
         else:
             acc_handle = acc.handle
-            assert acc.type == ret_ty
+            assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype
 
         # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
         if max_num_imprecise_acc is None:
@@ -1607,7 +1627,7 @@ class TritonSemantic(Generic[TensorTy]):
             acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0)
         else:
             acc_handle = acc.handle
-            assert acc.type == ret_ty
+            assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype
         rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle
         lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle
         return self.tensor(
@@ -1709,6 +1729,36 @@ class TritonSemantic(Generic[TensorTy]):
         gather = self.builder.create_gather(src.handle, index.handle, axis)
         return self.wrap_tensor(gather, src.type.scalar, index.type.shape)
 
+    # ===----------------------------------------------------------------------===
+    # Map Elementwise
+    # ===----------------------------------------------------------------------===
+
+    def broadcast_tensors(self, *inputs):
+        if not inputs:
+            return ()
+        head, *tail = inputs
+        for i in range(len(tail)):
+            head, tail[i] = self.broadcast_impl_value(head, tail[i])
+        for i in range(len(tail)):
+            head, tail[i] = self.broadcast_impl_value(head, tail[i])
+        return (head, *tail)
+
+    def map_elementwise(self, inputs: Sequence[tl.tensor], result_types: Sequence[tl.dtype], pack: int,
+                        region_builder_fn) -> Tuple[tl.tensor, ...]:
+        inputs = self.broadcast_tensors(*inputs)
+
+        assert len(inputs) > 0, "map_elementwise must have at least 1 input tensor"
+        result_types = [inputs[0].type.with_element_ty(ty.scalar) for ty in result_types]
+        elementwise_op = self.builder.create_map_elementwise(
+            [t.handle for t in inputs],
+            [ty.to_ir(self.builder) for ty in result_types],
+            pack,
+        )
+        region_builder_fn(elementwise_op)
+        # assert elementwise_op.verify()
+
+        return tuple(self.tensor(elementwise_op.get_result(i), ty) for i, ty in enumerate(result_types))
+
 
     # ===----------------------------------------------------------------------===
     # Histogram
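Note: broadcast_tensors above brings every input to a common block shape in two passes: the first pass settles the final shape of head, the second re-broadcasts each tail element against it. A rough NumPy analogue, for intuition only (lists of arrays, not the real tensor types):

    import numpy as np

    def broadcast_tensors(*inputs):
        if not inputs:
            return ()
        # one library call that converges to what the two passes above achieve
        return tuple(np.broadcast_arrays(*inputs))

    a, b, c = np.ones((4, 1)), np.ones((1, 8)), np.ones(8)
    print([x.shape for x in broadcast_tensors(a, b, c)])  # [(4, 8), (4, 8), (4, 8)]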
@@ -1760,9 +1810,11 @@ class TritonSemantic(Generic[TensorTy]):
         is_signed = [arg.dtype.is_int_signed() for arg in args]
         return self.tensor(self.builder.create_print(prefix, hex, new_args, is_signed), tl.void)
 
-    def device_assert(self, cond: TensorTy, msg: str) -> TensorTy:
+    def device_assert(self, cond: TensorTy, msg: str, mask: Optional[TensorTy]) -> TensorTy:
         if not self.builder.options.debug:
             return
+        if mask is not None:
+            cond = self.or_(cond, self.not_(mask))
         return self.tensor(self.builder.create_assert(cond.handle, msg), tl.void)
 
     def assume(self, cond) -> TensorTy:
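Note: device_assert now takes an optional mask and exempts masked-off lanes by OR-ing cond with NOT mask, so a load guard can double as an assert guard. A plain-Python sketch of that predication (list stand-ins, not the real tensor types):

    def masked_assert(cond, msg, mask=None):
        if mask is not None:
            # lanes where mask is False always pass the check
            cond = [c or not m for c, m in zip(cond, mask)]
        assert all(cond), msg

    masked_assert([True, False, True], "overflow", mask=[True, False, True])  # passes: lane 1 is masked off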
@@ -1788,7 +1840,7 @@ class TritonSemantic(Generic[TensorTy]):
             if elem.dtype != tl.int64 and require_i64:
                 return self.builder.create_int_cast(elem.handle, self.builder.get_int64_ty(),
                                                     elem.dtype.is_int_signed())
-            elif elem.dtype != tl.int32 and not require_i64:
+            elif elem.dtype == tl.int64 and not require_i64:
                 assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \
                     "add a `.to(tl.int32)` or use regular indexing for 64 bit support"
             return elem.handle
@@ -1844,13 +1896,8 @@ class TritonSemantic(Generic[TensorTy]):
         # Advanced block pointer type is the same as before
         return self.tensor(self.builder.create_advance(base.handle, offsets), base.type)
 
-    def make_tensor_descriptor(
-        self,
-        base: TensorTy,
-        shape: List[TensorTy],
-        strides: List[TensorTy],
-        block_shape: List[tl.constexpr],
-    ) -> tl.tensor_descriptor:
+    def make_tensor_descriptor(self, base: TensorTy, shape: List[TensorTy], strides: List[TensorTy],
+                               block_shape: List[tl.constexpr], padding_option: str = "zero") -> tl.tensor_descriptor:
         ndim = len(shape)
         if not (1 <= ndim <= 5):
             raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions")
@@ -1866,12 +1913,12 @@ class TritonSemantic(Generic[TensorTy]):
                 f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes"
             )
 
-        strides[-1] = tl._unwrap_if_constexpr(strides[-1])
-        if strides[-1] != 1:
-            raise ValueError(f"Tensor descriptor last dim must be 1 but got {strides[-1]}")
+        last_stride = tl._unwrap_if_constexpr(strides[-1])
+        if last_stride != 1:
+            raise ValueError(f"Tensor descriptor last dim must be 1 but got {last_stride}")
 
         shape = [self.make_scalar(x, tl.int32) for x in shape]
-        strides = [self.make_scalar(x, tl.int64) for x in strides]
+        strides = [self.make_scalar(tl._unwrap_if_constexpr(x), tl.int64) for x in strides]
 
         # Check whether `block_shape` is static
         block_shape = tl._unwrap_shape(block_shape)
@@ -1881,6 +1928,12 @@ class TritonSemantic(Generic[TensorTy]):
         base_handle = base.handle
         is_signed_int = base.type.element_ty.is_int_signed()
 
+        padding = self._str_to_padding_option(padding_option)
+
+        if base.type.element_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
+            raise ValueError("Padding option `nan` is not supported for integer blocks")
+
         handle = self.builder.create_make_tensor_descriptor(base_handle, [s.handle for s in shape],
-                                                            [s.handle for s in strides], block_shape, is_signed_int)
+                                                            [s.handle for s in strides], block_shape, is_signed_int,
+                                                            padding)
         return tl.tensor_descriptor(handle, shape, strides, type)
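Note: make_tensor_descriptor gains a padding_option argument ("zero" by default; "nan" is rejected for integer element types). A hedged sketch of how a kernel might pass the new option through the public tl.make_tensor_descriptor wrapper, assuming that wrapper forwards the keyword (shapes and strides here are illustrative):

    import triton
    import triton.language as tl

    @triton.jit
    def first_tile_sum(ptr, out_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        desc = tl.make_tensor_descriptor(
            ptr, shape=[M, N], strides=[N, 1], block_shape=[BLOCK_M, BLOCK_N],
            padding_option="nan",  # out-of-bounds elements read back as NaN instead of zero
        )
        tile = desc.load([0, 0])
        tl.store(out_ptr, tl.sum(tile))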
@@ -1,24 +1,25 @@
 from __future__ import annotations
 
-from ..runtime.jit import jit
+from ..runtime.jit import jit, constexpr_function
 from . import core
 from . import math
 
 # constexpr utilities
 
 
-def _log2(i: core.constexpr):
+@constexpr_function
+def _log2(i):
     log2 = 0
-    n = core.constexpr(i).value
+    n = i
     while n > 1:
         n >>= 1
         log2 += 1
-    return core.constexpr(log2)
+    return log2
 
 
-def _is_power_of_two(i: core.constexpr):
-    n = i.value
-    return core.constexpr((n & (n - 1)) == 0 and n != 0)
+@constexpr_function
+def _is_power_of_two(i):
+    return (i & (i - 1)) == 0 and i != 0
 
 
 # -----------------------
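Note: these helpers are now plain Python functions under @constexpr_function instead of hand-unwrapping core.constexpr values; they run while the kernel is being specialized and receive and return ordinary Python objects. A hedged sketch of a user-defined helper in the same style (the helper is ours, not part of the package):

    import triton
    import triton.language as tl
    from triton.runtime.jit import constexpr_function

    @constexpr_function
    def _next_pow2(n):
        # evaluated at compile/specialization time on plain Python ints
        p = 1
        while p < n:
            p *= 2
        return p

    @triton.jit
    def fill_iota(x_ptr, N: tl.constexpr):
        # _next_pow2(N) folds to a power-of-two constant, as tl.arange requires
        offs = tl.arange(0, _next_pow2(N))
        tl.store(x_ptr + offs, offs, mask=offs < N)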
@@ -263,8 +264,8 @@ def _sum_combine(a, b):
 # sum
 
 
-def _pick_sum_dtype(in_dtype: core.constexpr, dtype: core.constexpr):
-    dtype = core._unwrap_if_constexpr(dtype)
+@constexpr_function
+def _pick_sum_dtype(in_dtype, dtype):
     if dtype is not None:
         return dtype
 
@@ -316,9 +317,9 @@ def _or_combine(x, y):
 
 @core._tensor_member_fn
 @jit
-@core._add_reduction_docstr("reduce_of")
+@core._add_reduction_docstr("reduce_or")
 def reduce_or(input, axis, keep_dims=False):
-    core.static_assert(input.type.scalar.is_int(), "reduce_of only supported for integers")
+    core.static_assert(input.type.scalar.is_int(), "reduce_or only supported for integers")
     return core.reduce(input, axis, _or_combine, keep_dims=keep_dims)
 
 
@@ -476,14 +477,13 @@ def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = co
     return _bitonic_merge(x, n_dims, descending, n_dims)
 
 
+@constexpr_function
 def _get_flip_dim(dim, shape):
-    dim = core._unwrap_if_constexpr(dim)
-    shape = core._unwrap_if_constexpr(shape)
     if dim is None:
         dim = len(shape) - 1
     if dim < 0:  # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
         dim += len(shape)
-    return core.constexpr(dim)
+    return dim
 
 
 @core._tensor_member_fn
@@ -0,0 +1,54 @@
+from triton.runtime import driver
+from triton.runtime.jit import constexpr_function
+
+__all__ = ["current_target"]
+
+
+def current_target():
+    try:
+        active_driver = driver.active
+    except RuntimeError:
+        # If there is no active driver, return None
+        return None
+    return active_driver.get_current_target()
+
+
+current_target.__triton_builtin__ = True
+
+
+@constexpr_function
+def is_cuda():
+    target = current_target()
+    return target is not None and target.backend == "cuda"
+
+
+@constexpr_function
+def cuda_capability_geq(major, minor=0):
+    """
+    Determines whether we have compute capability >= (major, minor) and
+    returns this as a constexpr boolean. This can be used for guarding
+    inline asm implementations that require a certain compute capability.
+    """
+    target = current_target()
+    if target is None or target.backend != "cuda":
+        return False
+    assert isinstance(target.arch, int)
+    return target.arch >= major * 10 + minor
+
+
+@constexpr_function
+def is_hip():
+    target = current_target()
+    return target is not None and target.backend == "hip"
+
+
+@constexpr_function
+def is_hip_cdna3():
+    target = current_target()
+    return target is not None and target.arch == "gfx942"
+
+
+@constexpr_function
+def is_hip_cdna4():
+    target = current_target()
+    return target is not None and target.arch == "gfx950"
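Note: this new module (triton/language/target_info.py in the file list above) exposes constexpr predicates about the active backend, so library code can pick code paths at compile time. A hedged usage sketch; the kernel body is illustrative:

    import triton
    import triton.language as tl
    from triton.language.target_info import cuda_capability_geq, is_cuda

    @triton.jit
    def scale(x_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask)
        if is_cuda() and cuda_capability_geq(9, 0):
            x = x * 2  # stand-in for a Hopper-or-newer specific path
        else:
            x = x + 1  # generic fallback
        tl.store(x_ptr + offs, x, mask=mask)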
@@ -1,4 +1,5 @@
 from typing import Optional, Protocol
+from contextvars import ContextVar
 
 
 class Buffer(Protocol):
@@ -20,7 +21,7 @@ class NullAllocator:
                            "Use triton.set_allocator to specify an allocator.")
 
 
-_allocator: Allocator = NullAllocator()
+_allocator: ContextVar[Allocator] = ContextVar("_allocator", default=NullAllocator())
 
 
 def set_allocator(allocator: Allocator):
@@ -28,5 +29,16 @@ def set_allocator(allocator: Allocator):
     The allocator function is called during kernel launch for kernels that
     require additional global memory workspace.
     """
-    global _allocator
-    _allocator = allocator
+    _allocator.set(allocator)
+
+
+_profile_allocator: Allocator = ContextVar("_allocator", default=NullAllocator())
+
+
+def set_profile_allocator(allocator: Optional[Allocator]):
+    """
+    The profile allocator function is called before kernel launch for kernels
+    that require additional global memory workspace.
+    """
+    global _profile_allocator
+    _profile_allocator.set(allocator)
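Note: the allocator registry now lives in ContextVars, and a separate profile allocator is added alongside it. Registration still goes through triton.set_allocator; a hedged sketch using a torch-backed allocator (the (size, alignment, stream) signature follows the pattern used for TMA workspace allocation and is assumed here):

    import torch
    import triton

    def alloc_fn(size: int, alignment: int, stream):
        # any object exposing data_ptr() works; a raw byte tensor is the common choice
        return torch.empty(size, dtype=torch.int8, device="cuda")

    triton.set_allocator(alloc_fn)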
@@ -0,0 +1,55 @@
+from __future__ import annotations
+from typing import Callable, Optional
+from concurrent.futures import Executor, as_completed, Future
+from contextvars import ContextVar
+
+active_mode: ContextVar[Optional[AsyncCompileMode]] = ContextVar("async_compile_active_mode", default=None)
+
+
+class FutureKernel:
+
+    def __init__(self, finalize_compile: Callable, future: Future):
+        self.finalize_compile = finalize_compile
+        self.kernel = None
+        self.future = future
+
+    def result(self):
+        if self.kernel is not None:
+            return self.kernel
+
+        kernel = self.future.result()
+        self.finalize_compile(kernel)
+        self.kernel = kernel
+        return kernel
+
+
+class AsyncCompileMode:
+
+    def __init__(self, executor: Executor):
+        self.executor = executor
+        self.raw_futures = []
+        self.future_kernels = {}
+
+    def submit(self, key, compile_fn, finalize_fn):
+        future = self.future_kernels.get(key)
+        if future is not None:
+            return future
+
+        future = self.executor.submit(compile_fn)
+        future._key = key
+        self.raw_futures.append(future)
+        future_kernel = FutureKernel(finalize_fn, future)
+        self.future_kernels[key] = future_kernel
+        return future_kernel
+
+    def __enter__(self):
+        if active_mode.get() is not None:
+            raise RuntimeError("Another AsyncCompileMode is already active")
+        active_mode.set(self)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Finalize any outstanding compiles
+        for future in as_completed(self.raw_futures):
+            self.future_kernels[future._key].result()
+        active_mode.set(None)
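Note: AsyncCompileMode deduplicates compile jobs by key, runs them on an Executor, and finalizes every outstanding future when the context exits. A hedged sketch of driving it directly with a thread pool (the compile/finalize callables are stand-ins, not the hooks jit.py actually passes):

    from concurrent.futures import ThreadPoolExecutor
    from triton.runtime._async_compile import AsyncCompileMode

    def compile_fn():
        return "compiled-kernel"      # stand-in for the real compilation work

    def finalize_fn(kernel):
        print("finalized", kernel)    # stand-in for post-compile registration

    with ThreadPoolExecutor(max_workers=4) as pool:
        with AsyncCompileMode(pool) as mode:
            fut = mode.submit("kernel-key", compile_fn, finalize_fn)
        # __exit__ has already waited on and finalized every submitted future
        print(fut.result())           # "compiled-kernel"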
@@ -9,9 +9,11 @@ from functools import cached_property
 from typing import Dict, Tuple, List, Optional
 
 from .. import knobs
-from .jit import KernelInterface
+from .jit import KernelInterface, JITFunction
 from .errors import OutOfResources, PTXASError
 from .driver import driver
+from .cache import get_cache_manager, triton_key
+from triton._C.libtriton import get_cache_invalidating_env_vars
 
 
 class Autotuner(KernelInterface):
@@ -169,10 +171,7 @@ class Autotuner(KernelInterface):
             bench_fn()
             return False
 
-        from triton._C.libtriton import get_cache_invalidating_env_vars
-        from triton.compiler.compiler import make_backend, triton_key
-        from triton.runtime.cache import get_cache_manager
-        from triton.runtime.jit import JITFunction
+        from triton.compiler.compiler import make_backend
 
         fn = self.fn
         while not isinstance(fn, JITFunction):
triton/runtime/build.py CHANGED
@@ -56,10 +56,11 @@ def is_clang(cc):
     return cc == "clang" or cc == "clang.exe"
 
 
-def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
+def _cc_cmd(cc: str, src: str, out: str, include_dirs: list[str], library_dirs: list[str], libraries: list[str],
+            ccflags: list[str]) -> list[str]:
     if is_msvc(cc):
         out_base = os.path.splitext(out)[0]
-        cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/wd4819"]
+        cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/std:c11", "/wd4819"]
         cc_cmd += [f"/I{dir}" for dir in include_dirs if dir is not None]
         cc_cmd += [f"/Fo{out_base + '.obj'}"]
         cc_cmd += ["/link"]
@@ -79,16 +80,16 @@ def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
         cc_cmd += [f'-l{lib}' for lib in libraries]
         cc_cmd += [f"-L{dir}" for dir in library_dirs]
         cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
+        cc_cmd += ccflags
     return cc_cmd
 
 
-def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str],
-           libraries: list[str]) -> str:
+def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], libraries: list[str],
+           ccflags: list[str]) -> str:
     if impl := knobs.build.impl:
         return impl(name, src, srcdir, library_dirs, include_dirs, libraries)
     suffix = sysconfig.get_config_var('EXT_SUFFIX')
     so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
-    # try to avoid setuptools if possible
     cc = get_cc()
     # This function was renamed and made public in Python 3.10
     if hasattr(sysconfig, 'get_default_scheme'):
@@ -113,10 +114,10 @@ def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_di
     _, msvc_winsdk_inc_dirs, msvc_winsdk_lib_dirs = find_msvc_winsdk()
     include_dirs = include_dirs + msvc_winsdk_inc_dirs
     library_dirs = library_dirs + msvc_winsdk_lib_dirs
-    cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
+    cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries, ccflags)
 
     try:
-        ret = subprocess.check_call(cc_cmd)
+        subprocess.check_call(cc_cmd)
     except Exception as e:
         print("Failed to compile. cc_cmd:", cc_cmd)
         raise e
@@ -142,7 +143,8 @@ def _load_module_from_path(name: str, path: str) -> ModuleType:
 
 
 def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None,
-                            include_dirs: list[str] | None = None, libraries: list[str] | None = None) -> ModuleType:
+                            include_dirs: list[str] | None = None, libraries: list[str] | None = None,
+                            ccflags: list[str] | None = None) -> ModuleType:
     key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
     cache = get_cache_manager(key)
     suffix = sysconfig.get_config_var("EXT_SUFFIX")
@@ -159,7 +161,7 @@ def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None
         src_path = os.path.join(tmpdir, name + ".c")
         with open(src_path, "w") as f:
             f.write(src)
-        so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [])
+        so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [], ccflags or [])
         with open(so, "rb") as f:
            cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)
 
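Note: compile_module_from_src and _build now thread a ccflags list through to the compiler invocation (and MSVC builds always get /std:c11). A hedged sketch of passing extra flags when building a small CPython extension through this helper; the flag shown is a POSIX-style example and the embedded source is illustrative:

    from triton.runtime.build import compile_module_from_src

    src = r'''
    #define PY_SSIZE_T_CLEAN
    #include <Python.h>
    static struct PyModuleDef mod = {PyModuleDef_HEAD_INIT, "hello", NULL, -1, NULL};
    PyMODINIT_FUNC PyInit_hello(void) { return PyModule_Create(&mod); }
    '''

    hello = compile_module_from_src(src, "hello", ccflags=["-O2"])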
triton/runtime/cache.py CHANGED
@@ -5,8 +5,10 @@ from abc import ABC, abstractmethod
 from typing import Dict, List, Optional
 import base64
 import hashlib
+import functools
+import sysconfig
 
-from .. import knobs
+from triton import __version__, knobs
 
 
 class CacheManager(ABC):
@@ -272,3 +274,44 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
         key = f"{key}-{kwargs.get(kw)}"
     key = hashlib.sha256(key.encode("utf-8")).hexdigest()
     return _base32(key)
+
+
+@functools.lru_cache()
+def triton_key():
+    import pkgutil
+    TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    contents = []
+    # frontend
+    with open(__file__, "rb") as f:
+        contents += [hashlib.sha256(f.read()).hexdigest()]
+    # compiler
+    path_prefixes = [
+        (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."),
+        (os.path.join(TRITON_PATH, "backends"), "triton.backends."),
+    ]
+    for path, prefix in path_prefixes:
+        for lib in pkgutil.walk_packages([path], prefix=prefix):
+            with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
+                contents += [hashlib.sha256(f.read()).hexdigest()]
+
+    # backend
+    libtriton_hash = hashlib.sha256()
+    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
+    with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
+        while True:
+            chunk = f.read(1024**2)
+            if not chunk:
+                break
+            libtriton_hash.update(chunk)
+    contents.append(libtriton_hash.hexdigest())
+    # language
+    language_path = os.path.join(TRITON_PATH, 'language')
+    for lib in pkgutil.walk_packages([language_path], prefix="triton.language."):
+        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
+            contents += [hashlib.sha256(f.read()).hexdigest()]
+    return f'{__version__}' + '-'.join(contents)
+
+
+def get_cache_key(src, backend, backend_options, env_vars):
+    key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{backend_options.hash()}-{str(sorted(env_vars.items()))}"
+    return key
triton/runtime/driver.py CHANGED
@@ -2,8 +2,6 @@ from __future__ import annotations
 
 from ..backends import backends, DriverBase
 
-from typing import Any, Callable, Generic, TypeVar, Union
-
 
 def _create_driver() -> DriverBase:
     active_drivers = [x.driver for x in backends.values() if x.driver.is_active()]
@@ -12,52 +10,29 @@ def _create_driver() -> DriverBase:
     return active_drivers[0]()
 
 
-T = TypeVar("T")
-
-
-class LazyProxy(Generic[T]):
-
-    def __init__(self, init_fn: Callable[[], T]) -> None:
-        self._init_fn = init_fn
-        self._obj: Union[T, None] = None
-
-    def _initialize_obj(self) -> T:
-        if self._obj is None:
-            self._obj = self._init_fn()
-        return self._obj
-
-    def __getattr__(self, name) -> Any:
-        return getattr(self._initialize_obj(), name)
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        if name in ["_init_fn", "_obj"]:
-            super().__setattr__(name, value)
-        else:
-            setattr(self._initialize_obj(), name, value)
-
-    def __delattr__(self, name: str) -> None:
-        delattr(self._initialize_obj(), name)
-
-    def __repr__(self) -> str:
-        if self._obj is None:
-            return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
-        return repr(self._obj)
-
-    def __str__(self) -> str:
-        return str(self._initialize_obj())
-
-
 class DriverConfig:
 
     def __init__(self) -> None:
-        self.default: LazyProxy[DriverBase] = LazyProxy(_create_driver)
-        self.active: Union[LazyProxy[DriverBase], DriverBase] = self.default
+        self._default: DriverBase | None = None
+        self._active: DriverBase | None = None
+
+    @property
+    def default(self) -> DriverBase:
+        if self._default is None:
+            self._default = _create_driver()
+        return self._default
+
+    @property
+    def active(self) -> DriverBase:
+        if self._active is None:
+            self._active = self.default
+        return self._active
 
     def set_active(self, driver: DriverBase) -> None:
-        self.active = driver
+        self._active = driver
 
     def reset_active(self) -> None:
-        self.active = self.default
+        self._active = self.default
 
 
 driver = DriverConfig()
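Note: DriverConfig drops the LazyProxy indirection; driver.default and driver.active are now plain properties that instantiate the backend driver on first access. A short usage sketch, grounded in how target_info.py above consumes it:

    from triton.runtime.driver import driver

    # First attribute access runs _create_driver(); later accesses reuse the instance.
    target = driver.active.get_current_target()
    print(target.backend, target.arch)

    # Overriding and restoring the active driver still uses the same two entry points:
    # driver.set_active(my_driver); driver.reset_active()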