triton-windows 3.4.0.post20__cp313-cp313-win_amd64.whl → 3.5.0.post21__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of triton-windows has been flagged for review.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +8 -2
- triton/_filecheck.py +24 -14
- triton/_internal_testing.py +70 -4
- triton/_utils.py +3 -1
- triton/backends/amd/compiler.py +68 -60
- triton/backends/amd/driver.c +113 -44
- triton/backends/amd/driver.py +133 -57
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/compiler.py +80 -22
- triton/backends/nvidia/driver.c +88 -15
- triton/backends/nvidia/driver.py +130 -123
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +270 -163
- triton/compiler/compiler.py +45 -62
- triton/experimental/gluon/__init__.py +3 -2
- triton/experimental/gluon/_runtime.py +9 -6
- triton/experimental/gluon/language/__init__.py +117 -16
- triton/experimental/gluon/language/_core.py +246 -68
- triton/experimental/gluon/language/_layouts.py +398 -45
- triton/experimental/gluon/language/_math.py +17 -9
- triton/experimental/gluon/language/_semantic.py +130 -37
- triton/experimental/gluon/language/_standard.py +55 -22
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
- triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
- triton/experimental/gluon/nvidia/hopper.py +6 -1
- triton/knobs.py +132 -67
- triton/language/__init__.py +16 -10
- triton/language/core.py +163 -83
- triton/language/extra/cuda/gdc.py +6 -6
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +7 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/semantic.py +76 -23
- triton/language/standard.py +14 -14
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +4 -5
- triton/runtime/build.py +11 -9
- triton/runtime/cache.py +44 -1
- triton/runtime/driver.py +16 -41
- triton/runtime/interpreter.py +31 -23
- triton/runtime/jit.py +318 -157
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/tools/compile.py +62 -14
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +7 -9
- triton/windows_utils.py +42 -79
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/language/semantic.py
CHANGED

@@ -219,7 +219,7 @@ class TritonSemantic(Generic[TensorTy]):
         min_value = self.scalar_constant(min_value, tl.int64)
         cond = self.and_(self.less_equal(ret, max_value), self.greater_equal(ret, min_value))
         msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}"
-        self.device_assert(cond, msg)
+        self.device_assert(cond, msg, None)
 
     def add(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
             sanitize_overflow: bool) -> TensorTy:
@@ -619,6 +619,9 @@ class TritonSemantic(Generic[TensorTy]):
         ret_ty = tl.block_type(value.dtype, shape)
         return self.tensor(self.builder.create_splat(ret_ty.to_ir(self.builder), value.handle), ret_ty)
 
+    def unsplat(self, value: TensorTy) -> TensorTy:
+        return self.tensor(self.builder.create_unsplat(value.handle), value.dtype)
+
     def reshape(self, input: TensorTy, dst_shape: List[int], can_reorder: bool) -> TensorTy:
         numel = 1
         for s in dst_shape:
@@ -1034,9 +1037,9 @@ class TritonSemantic(Generic[TensorTy]):
         # Make `mask` and `other` into the same shape as `ptr`
         if ptr.type.is_block():
             if mask is not None:
-                mask = self.
+                ptr, mask = self.broadcast_impl_value(ptr, mask)
             if other is not None:
-                other = self.
+                ptr, other = self.broadcast_impl_value(ptr, other)
 
         # Get `pointer_type<elt_ty>` and `elt_ty`
         ptr_ty = ptr.type.scalar
@@ -1104,6 +1107,8 @@ class TritonSemantic(Generic[TensorTy]):
 
     def descriptor_store(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
         self.validate_store_like(desc, value, offsets)
+        # implicitly cast to the descriptor's type
+        value = self.cast(value, desc.dtype)
         offsets = self._convert_to_ir_values(offsets, require_i64=False)
         return self.tensor(self.builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void)
 
@@ -1472,10 +1477,10 @@ class TritonSemantic(Generic[TensorTy]):
             # All combinations of supported fp8 x fp8 are permitted
             pass
         else:
-            assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
-                                 tl.
-            assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16,
-                                 tl.
+            assert lhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32,
+                                 tl.float64), f"Unsupported lhs dtype {lhs.dtype}"
+            assert rhs.dtype in (tl.int8, tl.uint8, tl.float16, tl.bfloat16, tl.float32,
+                                 tl.float64), f"Unsupported rhs dtype {rhs.dtype}"
         assert lhs.dtype == rhs.dtype, f"Both operands must be same dtype. Got {lhs.dtype} and {rhs.dtype}"
 
         if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():
@@ -1487,6 +1492,18 @@ class TritonSemantic(Generic[TensorTy]):
             lhs = self.cast(lhs, tl.float16)
             rhs = self.cast(rhs, tl.float16)
 
+        uses_fp8e4b8 = lhs.dtype.is_fp8e4b8() or rhs.dtype.is_fp8e4b8()
+        uses_fp8e5b16 = lhs.dtype.is_fp8e5b16() or rhs.dtype.is_fp8e5b16()
+        if uses_fp8e4b8 or uses_fp8e5b16:
+            type_name = "fp8e4b8" if uses_fp8e4b8 else "fp8e5b16"
+            if type_name in self.builder.options.deprecated_fp8_dot_operand_dtypes:
+                arch = self.builder.options.arch
+                warnings.warn(
+                    f"{type_name} is AMD gfx942 specific and not supported on {arch} so it's upcasted to fp16 and can cause significant slow down. "
+                    f"Please use OCP fp8 variants on {arch} for performance")
+                lhs = self.cast(lhs, tl.float16)
+                rhs = self.cast(rhs, tl.float16)
+
         if input_precision is None:
             input_precision = self.builder.options.default_dot_input_precision
 
@@ -1514,6 +1531,9 @@ class TritonSemantic(Generic[TensorTy]):
         elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16():
             _0 = self.builder.get_fp32(0)
             ret_scalar_ty = tl.float32
+        elif lhs.type.scalar.is_fp64():
+            _0 = self.builder.get_fp64(0)
+            ret_scalar_ty = tl.float64
         else:
             _0 = self.builder.get_fp16(0) if out_dtype.is_fp16() else self.builder.get_fp32(0)
             ret_scalar_ty = out_dtype
@@ -1527,7 +1547,7 @@ class TritonSemantic(Generic[TensorTy]):
             acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0)
         else:
             acc_handle = acc.handle
-            assert acc.type == ret_ty
+            assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype
 
         # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
         if max_num_imprecise_acc is None:
@@ -1607,7 +1627,7 @@ class TritonSemantic(Generic[TensorTy]):
             acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0)
         else:
             acc_handle = acc.handle
-            assert acc.type == ret_ty
+            assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype
         rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle
         lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle
         return self.tensor(
@@ -1709,6 +1729,36 @@ class TritonSemantic(Generic[TensorTy]):
         gather = self.builder.create_gather(src.handle, index.handle, axis)
         return self.wrap_tensor(gather, src.type.scalar, index.type.shape)
 
+    # ===----------------------------------------------------------------------===
+    #                               Map Elementwise
+    # ===----------------------------------------------------------------------===
+
+    def broadcast_tensors(self, *inputs):
+        if not inputs:
+            return ()
+        head, *tail = inputs
+        for i in range(len(tail)):
+            head, tail[i] = self.broadcast_impl_value(head, tail[i])
+        for i in range(len(tail)):
+            head, tail[i] = self.broadcast_impl_value(head, tail[i])
+        return (head, *tail)
+
+    def map_elementwise(self, inputs: Sequence[tl.tensor], result_types: Sequence[tl.dtype], pack: int,
+                        region_builder_fn) -> Tuple[tl.tensor, ...]:
+        inputs = self.broadcast_tensors(*inputs)
+
+        assert len(inputs) > 0, "map_elementwise must have at least 1 input tensor"
+        result_types = [inputs[0].type.with_element_ty(ty.scalar) for ty in result_types]
+        elementwise_op = self.builder.create_map_elementwise(
+            [t.handle for t in inputs],
+            [ty.to_ir(self.builder) for ty in result_types],
+            pack,
+        )
+        region_builder_fn(elementwise_op)
+        # assert elementwise_op.verify()
+
+        return tuple(self.tensor(elementwise_op.get_result(i), ty) for i, ty in enumerate(result_types))
+
 
     # ===----------------------------------------------------------------------===
     #                               Histogram
@@ -1760,9 +1810,11 @@ class TritonSemantic(Generic[TensorTy]):
         is_signed = [arg.dtype.is_int_signed() for arg in args]
         return self.tensor(self.builder.create_print(prefix, hex, new_args, is_signed), tl.void)
 
-    def device_assert(self, cond: TensorTy, msg: str) -> TensorTy:
+    def device_assert(self, cond: TensorTy, msg: str, mask: Optional[TensorTy]) -> TensorTy:
         if not self.builder.options.debug:
             return
+        if mask is not None:
+            cond = self.or_(cond, self.not_(mask))
         return self.tensor(self.builder.create_assert(cond.handle, msg), tl.void)
 
     def assume(self, cond) -> TensorTy:
@@ -1788,7 +1840,7 @@ class TritonSemantic(Generic[TensorTy]):
             if elem.dtype != tl.int64 and require_i64:
                 return self.builder.create_int_cast(elem.handle, self.builder.get_int64_ty(),
                                                     elem.dtype.is_int_signed())
-            elif elem.dtype
+            elif elem.dtype == tl.int64 and not require_i64:
                 assert False, "Block pointers only support 32 bit `offsets/block_shape`, " \
                     "add a `.to(tl.int32)` or use regular indexing for 64 bit support"
             return elem.handle
@@ -1844,13 +1896,8 @@ class TritonSemantic(Generic[TensorTy]):
         # Advanced block pointer type is the same as before
         return self.tensor(self.builder.create_advance(base.handle, offsets), base.type)
 
-    def make_tensor_descriptor(
-
-        base: TensorTy,
-        shape: List[TensorTy],
-        strides: List[TensorTy],
-        block_shape: List[tl.constexpr],
-    ) -> tl.tensor_descriptor:
+    def make_tensor_descriptor(self, base: TensorTy, shape: List[TensorTy], strides: List[TensorTy],
+                               block_shape: List[tl.constexpr], padding_option: str = "zero") -> tl.tensor_descriptor:
         ndim = len(shape)
         if not (1 <= ndim <= 5):
             raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions")
@@ -1866,12 +1913,12 @@ class TritonSemantic(Generic[TensorTy]):
                 f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes"
             )
 
-
-        if
-            raise ValueError(f"Tensor descriptor last dim must be 1 but got {
+        last_stride = tl._unwrap_if_constexpr(strides[-1])
+        if last_stride != 1:
+            raise ValueError(f"Tensor descriptor last dim must be 1 but got {last_stride}")
 
         shape = [self.make_scalar(x, tl.int32) for x in shape]
-        strides = [self.make_scalar(x, tl.int64) for x in strides]
+        strides = [self.make_scalar(tl._unwrap_if_constexpr(x), tl.int64) for x in strides]
 
         # Check whether `block_shape` is static
         block_shape = tl._unwrap_shape(block_shape)
@@ -1881,6 +1928,12 @@ class TritonSemantic(Generic[TensorTy]):
         base_handle = base.handle
         is_signed_int = base.type.element_ty.is_int_signed()
 
+        padding = self._str_to_padding_option(padding_option)
+
+        if base.type.element_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN:
+            raise ValueError("Padding option `nan` is not supported for integer blocks")
+
         handle = self.builder.create_make_tensor_descriptor(base_handle, [s.handle for s in shape],
-                                                            [s.handle for s in strides], block_shape, is_signed_int
+                                                            [s.handle for s in strides], block_shape, is_signed_int,
+                                                            padding)
         return tl.tensor_descriptor(handle, shape, strides, type)
triton/language/standard.py
CHANGED

@@ -1,24 +1,25 @@
 from __future__ import annotations
 
-from ..runtime.jit import jit
+from ..runtime.jit import jit, constexpr_function
 from . import core
 from . import math
 
 # constexpr utilities
 
 
-
+@constexpr_function
+def _log2(i):
     log2 = 0
-    n =
+    n = i
     while n > 1:
         n >>= 1
         log2 += 1
-    return
+    return log2
 
 
-
-
-    return
+@constexpr_function
+def _is_power_of_two(i):
+    return (i & (i - 1)) == 0 and i != 0
 
 
 # -----------------------
@@ -263,8 +264,8 @@ def _sum_combine(a, b):
 # sum
 
 
-
-
+@constexpr_function
+def _pick_sum_dtype(in_dtype, dtype):
     if dtype is not None:
         return dtype
 
@@ -316,9 +317,9 @@ def _or_combine(x, y):
 
 @core._tensor_member_fn
 @jit
-@core._add_reduction_docstr("
+@core._add_reduction_docstr("reduce_or")
 def reduce_or(input, axis, keep_dims=False):
-    core.static_assert(input.type.scalar.is_int(), "
+    core.static_assert(input.type.scalar.is_int(), "reduce_or only supported for integers")
     return core.reduce(input, axis, _or_combine, keep_dims=keep_dims)
 
 
@@ -476,14 +477,13 @@ def bitonic_merge(x, dim: core.constexpr = None, descending: core.constexpr = co
     return _bitonic_merge(x, n_dims, descending, n_dims)
 
 
+@constexpr_function
 def _get_flip_dim(dim, shape):
-    dim = core._unwrap_if_constexpr(dim)
-    shape = core._unwrap_if_constexpr(shape)
     if dim is None:
         dim = len(shape) - 1
     if dim < 0: # flip doesn't work if dim < 0 because the xor-swap for loop will start/end at the wrong index
         dim += len(shape)
-    return
+    return dim
 
 
 @core._tensor_member_fn
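
Note on `constexpr_function`: the helpers above (`_log2`, `_is_power_of_two`, `_pick_sum_dtype`, `_get_flip_dim`) are now plain Python functions evaluated at compile time rather than jitted helpers. A hedged usage sketch, assuming the decorator is importable from `triton.runtime.jit` as shown in the diff and that its result can seed a `tl.constexpr` inside a jitted kernel; the kernel and helper names are invented for illustration:

import triton
import triton.language as tl
from triton.runtime.jit import constexpr_function


@constexpr_function
def next_pow2_log2(n):
    # runs in the Python interpreter during compilation, not on the device
    log2 = 0
    while (1 << log2) < n:
        log2 += 1
    return log2


@triton.jit
def zero_fill_kernel(x_ptr, n: tl.constexpr):
    BLOCK: tl.constexpr = 1 << next_pow2_log2(n)  # resolved at compile time
    offs = tl.arange(0, BLOCK)
    tl.store(x_ptr + offs, 0.0, mask=offs < n)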
triton/language/target_info.py
ADDED

@@ -0,0 +1,54 @@
+from triton.runtime import driver
+from triton.runtime.jit import constexpr_function
+
+__all__ = ["current_target"]
+
+
+def current_target():
+    try:
+        active_driver = driver.active
+    except RuntimeError:
+        # If there is no active driver, return None
+        return None
+    return active_driver.get_current_target()
+
+
+current_target.__triton_builtin__ = True
+
+
+@constexpr_function
+def is_cuda():
+    target = current_target()
+    return target is not None and target.backend == "cuda"
+
+
+@constexpr_function
+def cuda_capability_geq(major, minor=0):
+    """
+    Determines whether we have compute capability >= (major, minor) and
+    returns this as a constexpr boolean. This can be used for guarding
+    inline asm implementations that require a certain compute capability.
+    """
+    target = current_target()
+    if target is None or target.backend != "cuda":
+        return False
+    assert isinstance(target.arch, int)
+    return target.arch >= major * 10 + minor
+
+
+@constexpr_function
+def is_hip():
+    target = current_target()
+    return target is not None and target.backend == "hip"
+
+
+@constexpr_function
+def is_hip_cdna3():
+    target = current_target()
+    return target is not None and target.arch == "gfx942"
+
+
+@constexpr_function
+def is_hip_cdna4():
+    target = current_target()
+    return target is not None and target.arch == "gfx950"
triton/runtime/_allocation.py
CHANGED

@@ -1,4 +1,5 @@
 from typing import Optional, Protocol
+from contextvars import ContextVar
 
 
 class Buffer(Protocol):
@@ -20,7 +21,7 @@ class NullAllocator:
                            "Use triton.set_allocator to specify an allocator.")
 
 
-_allocator: Allocator = NullAllocator()
+_allocator: ContextVar[Allocator] = ContextVar("_allocator", default=NullAllocator())
 
 
 def set_allocator(allocator: Allocator):
@@ -28,5 +29,16 @@ def set_allocator(allocator: Allocator):
     The allocator function is called during kernel launch for kernels that
     require additional global memory workspace.
     """
-
-
+    _allocator.set(allocator)
+
+
+_profile_allocator: Allocator = ContextVar("_allocator", default=NullAllocator())
+
+
+def set_profile_allocator(allocator: Optional[Allocator]):
+    """
+    The profile allocator function is called before kernel launch for kernels
+    that require additional global memory workspace.
+    """
+    global _profile_allocator
+    _profile_allocator.set(allocator)
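
`set_allocator` (and the new `set_profile_allocator`) now store the callback in a `ContextVar` instead of a plain module global, so the setting is scoped to the current context. A minimal registration sketch, assuming the usual `(size, alignment, stream)` allocator signature and a PyTorch buffer as the workspace; the function name is illustrative:

import torch
import triton


def torch_allocator(size: int, alignment: int, stream):
    # any object exposing data_ptr() works as the scratch buffer
    return torch.empty(size, dtype=torch.int8, device="cuda")


triton.set_allocator(torch_allocator)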
triton/runtime/_async_compile.py
ADDED

@@ -0,0 +1,55 @@
+from __future__ import annotations
+from typing import Callable, Optional
+from concurrent.futures import Executor, as_completed, Future
+from contextvars import ContextVar
+
+active_mode: ContextVar[Optional[AsyncCompileMode]] = ContextVar("async_compile_active_mode", default=None)
+
+
+class FutureKernel:
+
+    def __init__(self, finalize_compile: Callable, future: Future):
+        self.finalize_compile = finalize_compile
+        self.kernel = None
+        self.future = future
+
+    def result(self):
+        if self.kernel is not None:
+            return self.kernel
+
+        kernel = self.future.result()
+        self.finalize_compile(kernel)
+        self.kernel = kernel
+        return kernel
+
+
+class AsyncCompileMode:
+
+    def __init__(self, executor: Executor):
+        self.executor = executor
+        self.raw_futures = []
+        self.future_kernels = {}
+
+    def submit(self, key, compile_fn, finalize_fn):
+        future = self.future_kernels.get(key)
+        if future is not None:
+            return future
+
+        future = self.executor.submit(compile_fn)
+        future._key = key
+        self.raw_futures.append(future)
+        future_kernel = FutureKernel(finalize_fn, future)
+        self.future_kernels[key] = future_kernel
+        return future_kernel
+
+    def __enter__(self):
+        if active_mode.get() is not None:
+            raise RuntimeError("Another AsyncCompileMode is already active")
+        active_mode.set(self)
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        # Finalize any outstanding compiles
+        for future in as_completed(self.raw_futures):
+            self.future_kernels[future._key].result()
+        active_mode.set(None)
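
The wiring that routes `JITFunction` launches through `AsyncCompileMode.submit` lives in the `runtime/jit.py` changes elsewhere in this diff; based on the context-manager shape above, usage presumably looks like the following sketch. The executor, kernel, and launch are illustrative assumptions, not taken from the diff:

import torch
import triton
import triton.language as tl
from concurrent.futures import ThreadPoolExecutor
from triton.runtime._async_compile import AsyncCompileMode


@triton.jit
def add_one(x_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)


x = torch.zeros(1024, device="cuda")
with ThreadPoolExecutor(max_workers=4) as pool:
    # inside the context, first-time compiles would be submitted to the pool;
    # __exit__ blocks until every outstanding compile has been finalized
    with AsyncCompileMode(pool):
        add_one[(4,)](x, x.numel(), BLOCK=256)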
triton/runtime/autotuner.py
CHANGED

@@ -9,9 +9,11 @@ from functools import cached_property
 from typing import Dict, Tuple, List, Optional
 
 from .. import knobs
-from .jit import KernelInterface
+from .jit import KernelInterface, JITFunction
 from .errors import OutOfResources, PTXASError
 from .driver import driver
+from .cache import get_cache_manager, triton_key
+from triton._C.libtriton import get_cache_invalidating_env_vars
 
 
 class Autotuner(KernelInterface):
@@ -169,10 +171,7 @@ class Autotuner(KernelInterface):
            bench_fn()
            return False
 
-        from triton.
-        from triton.compiler.compiler import make_backend, triton_key
-        from triton.runtime.cache import get_cache_manager
-        from triton.runtime.jit import JITFunction
+        from triton.compiler.compiler import make_backend
 
        fn = self.fn
        while not isinstance(fn, JITFunction):
triton/runtime/build.py
CHANGED

@@ -56,10 +56,11 @@ def is_clang(cc):
     return cc == "clang" or cc == "clang.exe"
 
 
-def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries
+def _cc_cmd(cc: str, src: str, out: str, include_dirs: list[str], library_dirs: list[str], libraries: list[str],
+            ccflags: list[str]) -> list[str]:
     if is_msvc(cc):
         out_base = os.path.splitext(out)[0]
-        cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/wd4819"]
+        cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/std:c11", "/wd4819"]
         cc_cmd += [f"/I{dir}" for dir in include_dirs if dir is not None]
         cc_cmd += [f"/Fo{out_base + '.obj'}"]
         cc_cmd += ["/link"]
@@ -79,16 +80,16 @@ def _cc_cmd(cc, src, out, include_dirs, library_dirs, libraries):
         cc_cmd += [f'-l{lib}' for lib in libraries]
         cc_cmd += [f"-L{dir}" for dir in library_dirs]
         cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
+        cc_cmd += ccflags
     return cc_cmd
 
 
-def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str],
-
+def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], libraries: list[str],
+           ccflags: list[str]) -> str:
     if impl := knobs.build.impl:
         return impl(name, src, srcdir, library_dirs, include_dirs, libraries)
     suffix = sysconfig.get_config_var('EXT_SUFFIX')
     so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
-    # try to avoid setuptools if possible
     cc = get_cc()
     # This function was renamed and made public in Python 3.10
     if hasattr(sysconfig, 'get_default_scheme'):
@@ -113,10 +114,10 @@ def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_di
         _, msvc_winsdk_inc_dirs, msvc_winsdk_lib_dirs = find_msvc_winsdk()
         include_dirs = include_dirs + msvc_winsdk_inc_dirs
         library_dirs = library_dirs + msvc_winsdk_lib_dirs
-    cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
+    cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries, ccflags)
 
     try:
-
+        subprocess.check_call(cc_cmd)
     except Exception as e:
         print("Failed to compile. cc_cmd:", cc_cmd)
         raise e
@@ -142,7 +143,8 @@ def _load_module_from_path(name: str, path: str) -> ModuleType:
 
 
 def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None,
-                            include_dirs: list[str] | None = None, libraries: list[str] | None = None
+                            include_dirs: list[str] | None = None, libraries: list[str] | None = None,
+                            ccflags: list[str] | None = None) -> ModuleType:
     key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
     cache = get_cache_manager(key)
     suffix = sysconfig.get_config_var("EXT_SUFFIX")
@@ -159,7 +161,7 @@ def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None
         src_path = os.path.join(tmpdir, name + ".c")
         with open(src_path, "w") as f:
             f.write(src)
-        so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [])
+        so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [], ccflags or [])
         with open(so, "rb") as f:
             cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)
 
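
The new `ccflags` argument is threaded from `compile_module_from_src` through `_build` into `_cc_cmd`, where it is appended to the compiler command line (and `/std:c11` is now always passed to MSVC). A hedged caller sketch; the C source string and the gcc/clang-style define are illustrative only, not taken from the diff:

from triton.runtime.build import compile_module_from_src

src = r'''
#include <Python.h>
static struct PyModuleDef m = {PyModuleDef_HEAD_INIT, "hello", NULL, -1, NULL};
PyMODINIT_FUNC PyInit_hello(void) { return PyModule_Create(&m); }
'''

# extra flags are appended verbatim to the compiler invocation
mod = compile_module_from_src(src, "hello", ccflags=["-DMY_FEATURE=1"])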
triton/runtime/cache.py
CHANGED

@@ -5,8 +5,10 @@ from abc import ABC, abstractmethod
 from typing import Dict, List, Optional
 import base64
 import hashlib
+import functools
+import sysconfig
 
-from
+from triton import __version__, knobs
 
 
 class CacheManager(ABC):
@@ -272,3 +274,44 @@ def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
         key = f"{key}-{kwargs.get(kw)}"
     key = hashlib.sha256(key.encode("utf-8")).hexdigest()
     return _base32(key)
+
+
+@functools.lru_cache()
+def triton_key():
+    import pkgutil
+    TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    contents = []
+    # frontend
+    with open(__file__, "rb") as f:
+        contents += [hashlib.sha256(f.read()).hexdigest()]
+    # compiler
+    path_prefixes = [
+        (os.path.join(TRITON_PATH, "compiler"), "triton.compiler."),
+        (os.path.join(TRITON_PATH, "backends"), "triton.backends."),
+    ]
+    for path, prefix in path_prefixes:
+        for lib in pkgutil.walk_packages([path], prefix=prefix):
+            with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
+                contents += [hashlib.sha256(f.read()).hexdigest()]
+
+    # backend
+    libtriton_hash = hashlib.sha256()
+    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
+    with open(os.path.join(TRITON_PATH, "_C", f"libtriton.{ext}"), "rb") as f:
+        while True:
+            chunk = f.read(1024**2)
+            if not chunk:
+                break
+            libtriton_hash.update(chunk)
+    contents.append(libtriton_hash.hexdigest())
+    # language
+    language_path = os.path.join(TRITON_PATH, 'language')
+    for lib in pkgutil.walk_packages([language_path], prefix="triton.language."):
+        with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f:
+            contents += [hashlib.sha256(f.read()).hexdigest()]
+    return f'{__version__}' + '-'.join(contents)
+
+
+def get_cache_key(src, backend, backend_options, env_vars):
+    key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{backend_options.hash()}-{str(sorted(env_vars.items()))}"
+    return key
triton/runtime/driver.py
CHANGED

@@ -2,8 +2,6 @@ from __future__ import annotations
 
 from ..backends import backends, DriverBase
 
-from typing import Any, Callable, Generic, TypeVar, Union
-
 
 def _create_driver() -> DriverBase:
     active_drivers = [x.driver for x in backends.values() if x.driver.is_active()]
@@ -12,52 +10,29 @@ def _create_driver() -> DriverBase:
     return active_drivers[0]()
 
 
-T = TypeVar("T")
-
-
-class LazyProxy(Generic[T]):
-
-    def __init__(self, init_fn: Callable[[], T]) -> None:
-        self._init_fn = init_fn
-        self._obj: Union[T, None] = None
-
-    def _initialize_obj(self) -> T:
-        if self._obj is None:
-            self._obj = self._init_fn()
-        return self._obj
-
-    def __getattr__(self, name) -> Any:
-        return getattr(self._initialize_obj(), name)
-
-    def __setattr__(self, name: str, value: Any) -> None:
-        if name in ["_init_fn", "_obj"]:
-            super().__setattr__(name, value)
-        else:
-            setattr(self._initialize_obj(), name, value)
-
-    def __delattr__(self, name: str) -> None:
-        delattr(self._initialize_obj(), name)
-
-    def __repr__(self) -> str:
-        if self._obj is None:
-            return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>"
-        return repr(self._obj)
-
-    def __str__(self) -> str:
-        return str(self._initialize_obj())
-
-
 class DriverConfig:
 
     def __init__(self) -> None:
-        self.
-        self.
+        self._default: DriverBase | None = None
+        self._active: DriverBase | None = None
+
+    @property
+    def default(self) -> DriverBase:
+        if self._default is None:
+            self._default = _create_driver()
+        return self._default
+
+    @property
+    def active(self) -> DriverBase:
+        if self._active is None:
+            self._active = self.default
+        return self._active
 
     def set_active(self, driver: DriverBase) -> None:
-        self.
+        self._active = driver
 
     def reset_active(self) -> None:
-        self.
+        self._active = self.default
 
 
 driver = DriverConfig()