triton-windows 3.4.0.post20__cp311-cp311-win_amd64.whl → 3.5.0.post21__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of triton-windows might be problematic.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +8 -2
- triton/_filecheck.py +24 -14
- triton/_internal_testing.py +70 -4
- triton/_utils.py +3 -1
- triton/backends/amd/compiler.py +68 -60
- triton/backends/amd/driver.c +113 -44
- triton/backends/amd/driver.py +133 -57
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/compiler.py +80 -22
- triton/backends/nvidia/driver.c +88 -15
- triton/backends/nvidia/driver.py +130 -123
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +270 -163
- triton/compiler/compiler.py +45 -62
- triton/experimental/gluon/__init__.py +3 -2
- triton/experimental/gluon/_runtime.py +9 -6
- triton/experimental/gluon/language/__init__.py +117 -16
- triton/experimental/gluon/language/_core.py +246 -68
- triton/experimental/gluon/language/_layouts.py +398 -45
- triton/experimental/gluon/language/_math.py +17 -9
- triton/experimental/gluon/language/_semantic.py +130 -37
- triton/experimental/gluon/language/_standard.py +55 -22
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
- triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
- triton/experimental/gluon/nvidia/hopper.py +6 -1
- triton/knobs.py +132 -67
- triton/language/__init__.py +16 -10
- triton/language/core.py +163 -83
- triton/language/extra/cuda/gdc.py +6 -6
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +7 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/semantic.py +76 -23
- triton/language/standard.py +14 -14
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +4 -5
- triton/runtime/build.py +11 -9
- triton/runtime/cache.py +44 -1
- triton/runtime/driver.py +16 -41
- triton/runtime/interpreter.py +31 -23
- triton/runtime/jit.py +318 -157
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/tools/compile.py +62 -14
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +7 -9
- triton/windows_utils.py +42 -79
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/runtime/interpreter.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 import ast
 import textwrap
 import inspect
-from typing import Tuple, List, Dict
+from typing import Tuple, List, Dict, Callable
 
 import math
 import numpy as np
@@ -77,17 +77,19 @@ class BlockPointerHandle:
 class TensorDescHandle:
 
     def __init__(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
-                 block_shape: List[int]):
+                 block_shape: List[int], padding):
         self.base = base
         self.ndim = len(shape)
         self.shape = shape
         self.strides = strides
         self.block_shape = block_shape
+        self.padding = padding
 
     def validate(self):
         assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned"
         assert len(self.strides) == self.ndim
         assert len(self.block_shape) == self.ndim
+        assert self.ndim >= 1, "descriptor cannot be 0 dimensional"
 
         for stride in self.strides[:-1]:
             assert stride.data.item() % 16 == 0, "stride must be 16-byte aligned"
@@ -663,6 +665,9 @@ class InterpreterBuilder:
         else:  # scalar
             return TensorHandle(np.full(shape, arg.data, dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
 
+    def create_unsplat(self, arg):
+        return TensorHandle(np.full((1, ), arg.data[0], dtype=_get_np_dtype(arg.dtype)), arg.dtype.scalar)
+
     def create_atomic_cas(self, ptr, cmp, val, sem, scope):
         if sem not in self.ir_sem_to_interpreter_sem:
             raise ValueError(f"unsupported semantic {sem}")
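
The new `create_unsplat` is the inverse of the splat above: it collapses a tensor whose lanes are known to hold one uniform value back into a single-element handle. A NumPy-only sketch of the semantics (illustrative, not the Triton API):

import numpy as np

def unsplat(data: np.ndarray) -> np.ndarray:
    # Every lane of a splat holds the same value, so keeping element 0
    # in a shape-(1,) array preserves all the information.
    return np.full((1, ), data.flat[0], dtype=data.dtype)

splatted = np.full((4, 4), 7, dtype=np.int32)
print(unsplat(splatted))  # [7]
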
@@ -725,15 +730,9 @@ class InterpreterBuilder:
             ret.offsets[i].data += offsets[i].data
         return ret
 
-    def create_make_tensor_descriptor(
-        self,
-        base: TensorHandle,
-        shape: List[TensorHandle],
-        strides: List[TensorHandle],
-        tensor_shape: List[int],
-        is_signed: bool,
-    ):
-        desc = TensorDescHandle(base, shape, strides, tensor_shape)
+    def create_make_tensor_descriptor(self, base: TensorHandle, shape: List[TensorHandle], strides: List[TensorHandle],
+                                      tensor_shape: List[int], is_signed: bool, padding: str = "zero"):
+        desc = TensorDescHandle(base, shape, strides, tensor_shape, padding)
         desc.validate()
         return desc
@@ -741,7 +740,16 @@ class InterpreterBuilder:
                                eviction_policy):
         assert isinstance(desc, TensorDescHandle)
         ptrs, mask = desc.materialize_pointers(indices)
-        return self.create_masked_load(ptrs, mask, None, cache_modifier=cache_modifier,
+        dtype_tt = ptrs.get_element_ty()
+        dtype_np = _get_np_dtype(dtype_tt)
+        padding = desc.padding
+        if padding == _ir.PADDING_OPTION.PAD_ZERO:
+            other = TensorHandle(np.zeros_like(ptrs.data, dtype=dtype_np), dtype_tt)
+        elif padding == _ir.PADDING_OPTION.PAD_NAN:
+            other = TensorHandle(np.full_like(ptrs.data, float('nan'), dtype=dtype_np), dtype_tt)
+        else:
+            raise ValueError(f"unsupported padding {padding}")
+        return self.create_masked_load(ptrs, mask, other, cache_modifier=cache_modifier,
                                        eviction_policy=eviction_policy, is_volatile=False)
 
     def create_descriptor_store(self, desc: TensorDescHandle, value: TensorHandle, indices: List[TensorHandle]):
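
Descriptor loads now fill out-of-bounds lanes according to the descriptor's padding option (zero or NaN) instead of an unspecified fill. A NumPy-only sketch of the same policy, with illustrative names rather than the Triton API:

import numpy as np

def pad_out_of_bounds(loaded: np.ndarray, in_bounds: np.ndarray, padding: str) -> np.ndarray:
    # Same choice the interpreter makes: masked-off lanes become 0 or NaN.
    if padding == "zero":
        other = np.zeros_like(loaded)
    elif padding == "nan":
        other = np.full_like(loaded, float('nan'))
    else:
        raise ValueError(f"unsupported padding {padding}")
    return np.where(in_bounds, loaded, other)

x = np.array([1.0, 2.0, 3.0, 4.0])
in_bounds = np.array([True, True, False, False])
print(pad_out_of_bounds(x, in_bounds, "zero"))  # [1. 2. 0. 0.]
print(pad_out_of_bounds(x, in_bounds, "nan"))   # [1. 2. nan nan]
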
@@ -934,9 +942,9 @@ class ReduceOps(ReduceScanOpInterface):
         elif self.combine_fn == tl.standard._argmax_combine_tie_break_left:
             return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax)
         elif self.combine_fn == tl.standard._elementwise_max:
-            return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None)
+            return self.min_max(input[0], val_reduce_op=np.nanmax, idx_reduce_op=None)
         elif self.combine_fn == tl.standard._elementwise_min:
-            return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None)
+            return self.min_max(input[0], val_reduce_op=np.nanmin, idx_reduce_op=None)
         elif self.combine_fn == tl.standard._sum_combine:
             return self.sum(input[0])
         else:
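
The interpreter's elementwise max/min reductions switch from NaN-propagating to NaN-ignoring NumPy reducers. The behavioral difference in two lines:

import numpy as np

x = np.array([1.0, np.nan, 3.0])
print(np.max(x))     # nan -- a single NaN poisons the plain reduction
print(np.nanmax(x))  # 3.0 -- the nan* variant skips NaN inputs
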
@@ -1125,7 +1133,7 @@ def _tuple_create(arg, contents):
 # TODO: wrap everything in triton tensors
 def _implicit_cvt(arg):
     if isinstance(arg, int):
-        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
+        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
         dtype = np.int32
         if -2**31 <= arg < 2**31:
             dtype = np.int32
@@ -1140,7 +1148,7 @@ def _implicit_cvt(arg):
         handle = TensorHandle(np.array([arg], dtype=dtype), ty)
         return tl.tensor(handle, ty)
     if hasattr(arg, "data_ptr"):
-        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg))
+        ty = tl.str_to_ty(triton.runtime.jit.mangle_type(arg), None)
         handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
         return tl.tensor(handle, ty)
     elif isinstance(arg, tuple):
@@ -1150,12 +1158,10 @@ def _implicit_cvt(arg):
         assert arg.strides[-1] == 1
         strides[-1] = tl.constexpr(1)
         semantic = TritonSemantic(InterpreterBuilder())
-        return semantic.make_tensor_descriptor(
-            base=_implicit_cvt(arg.base),
-            shape=[_implicit_cvt(s) for s in arg.shape],
-            strides=strides,
-            block_shape=[tl.constexpr(b) for b in arg.block_shape],
-        )
+        return semantic.make_tensor_descriptor(base=_implicit_cvt(arg.base),
+                                               shape=[_implicit_cvt(s) for s in arg.shape], strides=strides,
+                                               block_shape=[tl.constexpr(b)
+                                                            for b in arg.block_shape], padding_option=arg.padding)
     return arg
 
 
@@ -1198,6 +1204,7 @@ class GridExecutor:
                 arg.shape,
                 arg.strides,
                 arg.block_shape,
+                arg.padding,
             )
         elif not hasattr(arg, "data_ptr"):
             return arg
@@ -1368,11 +1375,12 @@ class FunctionRewriter:
 
 class InterpretedFunction:
     # Cache all rewritten functions
-    rewritten_fn = {}
+    rewritten_fn: Dict[Callable, Callable] = {}
 
     def __init__(self, fn, **kwargs) -> None:
         self.fn = fn
         self.rewriter = FunctionRewriter(fn, **kwargs)
+        self.kwargs = kwargs
 
     def run(*args, **kwargs):
         grid = kwargs["grid"]