triton-windows 3.4.0.post20__cp310-cp310-win_amd64.whl → 3.5.0.post21__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +8 -2
- triton/_filecheck.py +24 -14
- triton/_internal_testing.py +70 -4
- triton/_utils.py +3 -1
- triton/backends/amd/compiler.py +68 -60
- triton/backends/amd/driver.c +113 -44
- triton/backends/amd/driver.py +133 -57
- triton/backends/driver.py +13 -0
- triton/backends/nvidia/compiler.py +80 -22
- triton/backends/nvidia/driver.c +88 -15
- triton/backends/nvidia/driver.py +130 -123
- triton/compiler/__init__.py +5 -2
- triton/compiler/code_generator.py +270 -163
- triton/compiler/compiler.py +45 -62
- triton/experimental/gluon/__init__.py +3 -2
- triton/experimental/gluon/_runtime.py +9 -6
- triton/experimental/gluon/language/__init__.py +117 -16
- triton/experimental/gluon/language/_core.py +246 -68
- triton/experimental/gluon/language/_layouts.py +398 -45
- triton/experimental/gluon/language/_math.py +17 -9
- triton/experimental/gluon/language/_semantic.py +130 -37
- triton/experimental/gluon/language/_standard.py +55 -22
- triton/experimental/gluon/language/amd/__init__.py +4 -0
- triton/experimental/gluon/language/amd/_layouts.py +96 -0
- triton/experimental/gluon/language/amd/cdna3/__init__.py +100 -0
- triton/experimental/gluon/language/amd/cdna4/__init__.py +48 -0
- triton/experimental/gluon/language/amd/cdna4/async_copy.py +151 -0
- triton/experimental/gluon/language/extra/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/__init__.py +3 -0
- triton/experimental/gluon/language/nvidia/ampere/async_copy.py +74 -0
- triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +80 -0
- triton/experimental/gluon/language/nvidia/blackwell/__init__.py +192 -7
- triton/experimental/gluon/language/nvidia/blackwell/tma.py +20 -0
- triton/experimental/gluon/language/nvidia/hopper/__init__.py +124 -3
- triton/experimental/gluon/language/nvidia/hopper/mbarrier.py +20 -37
- triton/experimental/gluon/language/nvidia/hopper/tma.py +4 -3
- triton/experimental/gluon/nvidia/hopper.py +6 -1
- triton/knobs.py +132 -67
- triton/language/__init__.py +16 -10
- triton/language/core.py +163 -83
- triton/language/extra/cuda/gdc.py +6 -6
- triton/language/extra/hip/__init__.py +3 -1
- triton/language/extra/hip/libdevice.py +7 -0
- triton/language/extra/hip/utils.py +35 -0
- triton/language/extra/libdevice.py +4 -0
- triton/language/semantic.py +76 -23
- triton/language/standard.py +14 -14
- triton/language/target_info.py +54 -0
- triton/runtime/_allocation.py +15 -3
- triton/runtime/_async_compile.py +55 -0
- triton/runtime/autotuner.py +4 -5
- triton/runtime/build.py +11 -9
- triton/runtime/cache.py +44 -1
- triton/runtime/driver.py +16 -41
- triton/runtime/interpreter.py +31 -23
- triton/runtime/jit.py +318 -157
- triton/runtime/tcc/include/_mingw.h +8 -10
- triton/runtime/tcc/include/assert.h +5 -0
- triton/runtime/tcc/include/errno.h +1 -1
- triton/runtime/tcc/include/float.h +21 -3
- triton/runtime/tcc/include/iso646.h +36 -0
- triton/runtime/tcc/include/limits.h +5 -0
- triton/runtime/tcc/include/malloc.h +2 -2
- triton/runtime/tcc/include/math.h +21 -261
- triton/runtime/tcc/include/stdalign.h +16 -0
- triton/runtime/tcc/include/stdarg.h +5 -70
- triton/runtime/tcc/include/stdatomic.h +171 -0
- triton/runtime/tcc/include/stddef.h +7 -19
- triton/runtime/tcc/include/stdlib.h +15 -4
- triton/runtime/tcc/include/stdnoreturn.h +7 -0
- triton/runtime/tcc/include/sys/stat.h +2 -2
- triton/runtime/tcc/include/sys/types.h +5 -0
- triton/runtime/tcc/include/tcc/tcc_libm.h +444 -27
- triton/runtime/tcc/include/tccdefs.h +342 -0
- triton/runtime/tcc/include/tgmath.h +89 -0
- triton/runtime/tcc/include/uchar.h +33 -0
- triton/runtime/tcc/include/unistd.h +1 -0
- triton/runtime/tcc/include/winapi/qos.h +72 -0
- triton/runtime/tcc/include/winapi/shellapi.h +59 -0
- triton/runtime/tcc/include/winapi/winbase.h +9 -2
- triton/runtime/tcc/include/winapi/wincon.h +8 -0
- triton/runtime/tcc/include/winapi/windows.h +1 -1
- triton/runtime/tcc/include/winapi/winnls.h +778 -0
- triton/runtime/tcc/include/winapi/winnt.h +9 -7
- triton/runtime/tcc/include/winapi/winsock2.h +1474 -0
- triton/runtime/tcc/include/winapi/ws2ipdef.h +21 -0
- triton/runtime/tcc/include/winapi/ws2tcpip.h +391 -0
- triton/runtime/tcc/lib/libtcc1.a +0 -0
- triton/runtime/tcc/lib/python314.def +1800 -0
- triton/runtime/tcc/lib/python314t.def +1809 -0
- triton/runtime/tcc/libtcc.dll +0 -0
- triton/runtime/tcc/tcc.exe +0 -0
- triton/tools/compile.py +62 -14
- triton/tools/extra/cuda/compile.c +1 -0
- triton/tools/extra/hip/compile.cpp +66 -0
- triton/tools/extra/hip/compile.h +13 -0
- triton/tools/ragged_tma.py +92 -0
- triton/tools/tensor_descriptor.py +7 -9
- triton/windows_utils.py +42 -79
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/METADATA +3 -4
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/RECORD +106 -75
- triton/runtime/tcc/lib/libtcc1-64.a +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/WHEEL +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/entry_points.txt +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/licenses/LICENSE +0 -0
- {triton_windows-3.4.0.post20.dist-info → triton_windows-3.5.0.post21.dist-info}/top_level.txt +0 -0
triton/language/__init__.py
CHANGED

@@ -55,9 +55,10 @@ from .core import (
     cat,
     cast,
     clamp,
+    condition,
     const,
     constexpr,
-    constexpr_function,
+    constexpr_type,
     debug_barrier,
     device_assert,
     device_print,
@@ -85,6 +86,7 @@ from .core import (
     join,
     load,
     make_block_ptr,
+    map_elementwise,
     max_constancy,
     max_contiguous,
     maximum,
@@ -130,6 +132,7 @@ from .random import (
     randn4x,
     uint_to_uniform_float,
 )
+from . import target_info
 
 __all__ = [
     "PropagateNan",
@@ -165,9 +168,10 @@ __all__ = [
     "cdiv",
     "ceil",
     "clamp",
+    "condition",
     "const",
     "constexpr",
-    "constexpr_function",
+    "constexpr_type",
     "cos",
     "cumprod",
     "cumsum",
@@ -210,6 +214,7 @@ __all__ = [
     "log",
     "log2",
     "make_block_ptr",
+    "map_elementwise",
     "math",
     "max",
     "max_constancy",
@@ -252,6 +257,7 @@ __all__ = [
     "store",
     "sum",
     "swizzle2d",
+    "target_info",
     "tensor",
     "topk",
     "trans",
@@ -271,12 +277,12 @@ __all__ = [
 ]
 
 
-def str_to_ty(name):
+def str_to_ty(name, c):
     from builtins import tuple
 
     if isinstance(name, tuple):
         fields = type(name).__dict__.get("_fields", None)
-        return tuple_type([str_to_ty(x) for x in name], fields)
+        return tuple_type([str_to_ty(x, c) for x in name], fields)
 
     if name[0] == "*":
         name = name[1:]
@@ -284,17 +290,17 @@ def str_to_ty(name):
         if name[0] == "k":
             name = name[1:]
             const = True
-        ty = str_to_ty(name)
+        ty = str_to_ty(name, c)
         return pointer_type(element_ty=ty, const=const)
 
     if name.startswith("tensordesc"):
         inner = name.split("<")[1].rstrip(">")
-        dtype, rest = inner.split("[", maxsplit=
-        block_shape, rest = rest.split("]", maxsplit=
+        dtype, rest = inner.split("[", maxsplit=1)
+        block_shape, rest = rest.split("]", maxsplit=1)
         block_shape = [int(s.strip()) for s in block_shape.rstrip("]").split(",")]
         layout = rest.lstrip(",")
         is_gluon = len(layout)
-        dtype = str_to_ty(dtype)
+        dtype = str_to_ty(dtype, None)
         ndim = len(block_shape)
         shape_type = tuple_type([int32] * ndim)
         # FIXME: Last dim stride should be constexpr(1)
@@ -308,8 +314,8 @@ def str_to_ty(name):
             return gluon_tensor_descriptor_type(block, shape_type, stride_type, layout)
         return tensor_descriptor_type(block, shape_type, stride_type)
 
-    if name == "constexpr":
-        return constexpr
+    if name.startswith("constexpr"):
+        return constexpr_type(c)
 
     tys = {
         "fp8e4nv": float8e4nv,
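For context on the `str_to_ty` change above: the new second argument threads a concrete value through signature parsing so that `constexpr` entries resolve to a `constexpr_type` carrying that value. A minimal sketch of the new behavior, inferred only from this hunk (this is an internal helper, not a documented API):

    from triton.language import str_to_ty, pointer_type, constexpr_type

    ty = str_to_ty("*fp32", None)    # pointer signatures ignore the value argument
    assert isinstance(ty, pointer_type)

    ty = str_to_ty("constexpr", 16)  # "constexpr" entries now capture the concrete value
    assert ty == constexpr_type(16)  # equality per the new constexpr_type.__eq__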
triton/language/core.py
CHANGED

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import math
 from warnings import warn
 from contextlib import contextmanager
 from enum import Enum
@@ -9,7 +10,7 @@ from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
 from dataclasses import dataclass
 import builtins
 from .. import knobs
-from ..runtime.jit import jit, JITFunction
+from ..runtime.jit import JITCallable
 import inspect
 
 from .._C.libtriton import ir
@@ -86,7 +87,7 @@ def _tensor_member_fn(fn: T) -> T:
     if is_builtin(fn):
         setattr(wrapper, TRITON_BUILTIN, True)
 
-    setattr(tensor, fn.__name__, fn if isinstance(fn, JITFunction) else wrapper)
+    setattr(tensor, fn.__name__, fn if isinstance(fn, JITCallable) else wrapper)
     return fn
 
 
@@ -152,10 +153,10 @@ class base_value:
 
 class base_type:
 
-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
         raise NotImplementedError("Types must implement __eq__")
 
-    def __ne__(self, other):
+    def __ne__(self, other) -> bool:
         return not (self == other)
 
     def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
@@ -178,10 +179,13 @@ class constexpr_type(base_type):
         self.value = value
 
     def __eq__(self, other):
-        return self.value == other.value
+        return isinstance(other, constexpr_type) and self.value == other.value
 
     def __repr__(self) -> str:
-        return f"constexpr[{self.value}]"
+        return f"constexpr_type[{self.value}]"
+
+    def __hash__(self):
+        return hash(self.value)
 
     def mangle(self) -> str:
         return repr(self)
@@ -199,15 +203,17 @@ class constexpr(base_value):
     """
 
     def __init__(self, value):
-        if isinstance(value, constexpr):
-            self.value = value.value
-        else:
-            self.value = value
+        while isinstance(value, constexpr):
+            value = value.value
+        self.value = value
         self.type = constexpr_type(value)
 
     def __repr__(self) -> str:
         return f"constexpr[{self.value}]"
 
+    def __hash__(self):
+        return hash((self.value, self.type))
+
     def _flatten_ir(self, handles: List[ir.value]) -> None:
         return
 
@@ -334,32 +340,6 @@ class constexpr(base_value):
         return self.value.__getitem__(*args)
 
 
-def constexpr_function(f):
-    """
-    Wraps an arbitrary Python function so that it can be called at
-    compile-time on constexpr arguments in a Triton function and
-    returns a constexpr result.
-    """
-
-    @wraps(f)
-    def wrapper(*args, _semantic=None, **kwargs):
-        # de-constexpr arguments and discard the _semantic keyword argument:
-        args = [_unwrap_if_constexpr(x) for x in args]
-        kwargs = {k: _unwrap_if_constexpr(v) for (k, v) in kwargs.items()}
-
-        # call the raw Python function f:
-        res = f(*args, **kwargs)
-
-        # convert result back to a Triton constexpr:
-        return constexpr(res)
-
-    # disguise the function as a Triton builtin to avoid raising an error
-    # that we're calling a non-JIT function from within a Triton kernel:
-    wrapper.__triton_builtin__ = True
-    wrapper.__module__ = constexpr_function.__module__
-    return wrapper
-
-
 CONSTEXPR_0 = constexpr(0)
 
 
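Two behavioral changes above are worth noting: `constexpr` now flattens nested wrappers in its constructor and gains `__hash__` (as does `constexpr_type`), while the `constexpr_function` decorator is removed outright. A small plain-Python sketch of the new semantics, assuming only what the hunks show:

    import triton.language as tl

    c = tl.constexpr(tl.constexpr(8))
    assert c.value == 8                      # nested constexpr values are unwrapped by the while loop
    assert hash(c) == hash(tl.constexpr(8))  # hashable: hash((value, constexpr_type(value)))
    cache = {c: "specialized"}               # so constexpr values can now key dicts and sets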
@@ -572,7 +552,8 @@ class dtype(base_type):
     def is_const():
         return False
 
-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
+        other = _unwrap_if_constexpr(other)
         if not isinstance(other, dtype):
             return False
         return self.name == other.name
@@ -696,7 +677,8 @@ class pointer_type(dtype):
     def is_const(self):
         return self.const
 
-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
+        other = _unwrap_if_constexpr(other)
         if not isinstance(other, pointer_type):
             return False
         return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const
@@ -753,6 +735,10 @@ class block_type(dtype):
     def scalar(self):
         return self.element_ty
 
+    @property
+    def nbytes(self):
+        return self.numel * (self.element_ty.primitive_bitwidth // 8)
+
     def mangle(self) -> str:
         elt = self.scalar.mangle()
         shape = '_'.join(map(str, self.shape))
@@ -879,10 +865,7 @@ class tensor(base_value):
         self.handle = handle
         # Block shape
         self.shape = type.shape if type.is_block() else ()
-        self.numel = 1
-        for s in self.shape:
-            self.numel *= s
-        self.numel = constexpr(self.numel)
+        self.numel = constexpr(math.prod(self.shape))
         self.type = type  # Tensor type (can be block_type)
         # Following the practice in pytorch, dtype is scalar type
         self.dtype = type.scalar
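The new `block_type.nbytes` property derives the block's byte size from its element count and element bit-width; a sketch of the arithmetic (the `block_type(element_ty, shape)` constructor and its `numel` attribute are assumed from Triton's existing API):

    import triton.language as tl

    ty = tl.block_type(tl.float32, [64, 32])
    assert ty.numel == 64 * 32       # element count, as before
    assert ty.nbytes == 64 * 32 * 4  # numel * (primitive_bitwidth // 8); fp32 is 4 bytes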
@@ -1268,19 +1251,20 @@ class tensor(base_value):
     ...
 
 
-class tuple(base_value):
+def _type_for_tuple_values(values, fields=None):
+    return tuple_type([constexpr_type(x) if isinstance(x, (int, float, dtype)) else x.type for x in values], fields)
 
-    def __init__(self, args: Sequence, type: tuple_type = None):
-        self.values = [i for i in args]
 
-        def get_type(x):
-            if isinstance(x, dtype):
-                return dtype
-            if isinstance(x, (int, float)):
-                return constexpr
-            return x.type
+class tuple(base_value):
 
-        self.type = type or tuple_type([get_type(x) for x in self.values])
+    def __init__(self, args: Sequence, type: Optional[tuple_type] = None):
+        self.values = [i for i in args]
+        if isinstance(type, tuple_type):
+            self.type = type
+        elif type is not None:  # make_template in ASTFunction.deserialize may pass us a list/tuple
+            self.type = tuple_type(type)
+        else:
+            self.type = _type_for_tuple_values(self.values)
 
     def __getitem__(self, idx: constexpr):
         if isinstance(idx, int):
@@ -1295,11 +1279,11 @@ class tuple(base_value):
         return self.values[self.type.fields.index(name)]
 
     # TODO: remove
-    def __setitem__(self, idx, value):
-        if isinstance(idx, int):
-            idx = constexpr(idx)
-        assert isinstance(idx, constexpr)
+    def _setitem(self, idx, value):
+        idx = _unwrap_if_constexpr(idx)
+        assert isinstance(idx, int)
         self.values[idx] = value
+        self.type = _type_for_tuple_values(self.values, self.type.fields)
 
     def __add__(self, other):
         other = _normalize_tuple(other)
@@ -1560,7 +1544,7 @@ def _aggregate(cls):
         def __new__(this_cls, *args, _semantic=None, _generator=None, **kwargs):
             # Call into the user-defined constructor.
             instance = this_cls._get_instance()
-            if isinstance(cls.__init__, JITFunction):
+            if isinstance(cls.__init__, JITCallable):
                 raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function")
             extra_kwargs = {}
             if "_semantic" in inspect.signature(cls.__init__).parameters:
@@ -1594,7 +1578,7 @@ def _aggregate(cls):
            [(name, getattr(self, name).type) for name in cls.__annotations__.keys()])
 
    for (name, member) in inspect.getmembers(cls):
-        if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITFunction):
+        if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITCallable):
            if name != "__init__":
                setattr(aggregate_value, name, member)
 
@@ -1828,11 +1812,6 @@ def join(a, b, _semantic=None):
     return _semantic.join(a, b)
 
 
-@jit
-def _take_first(a, b):
-    return a
-
-
 def _unsplat(x, _semantic=None, _generator=None):
     """
     Convert a single-element tensor to a scalar.
@@ -1843,10 +1822,7 @@ def _unsplat(x, _semantic=None, _generator=None):
     for d in x.shape:
         numel *= d
     assert numel == 1, "can only unsplat single-element tensors"
-
-    x = _semantic.reshape(x, [1])
-    x = typing.cast(tensor, reduce(x, 0, _take_first, _semantic=_semantic, _generator=_generator))
-    return x
+    return _semantic.unsplat(x)
 
 
 @_tensor_member_fn
@@ -2252,6 +2228,7 @@ def make_tensor_descriptor(
     shape: List[tensor],
     strides: List[tensor],
     block_shape: List[constexpr],
+    padding_option="zero",
     _semantic=None,
 ) -> tensor_descriptor:
     """Make a tensor descriptor object
@@ -2301,7 +2278,9 @@ def make_tensor_descriptor(
         inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)
 
     """
-    return _semantic.make_tensor_descriptor(base, shape, strides, block_shape)
+
+    padding_option = _unwrap_if_constexpr(padding_option)
+    return _semantic.make_tensor_descriptor(base, shape, strides, block_shape, padding_option)
 
 
 # -----------------------
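`make_tensor_descriptor` gains a `padding_option` keyword (default `"zero"`), unwrapped from constexpr and forwarded to the semantic layer; only the default value is visible in this hunk, so other accepted values are not assumed here. A usage sketch:

    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
        desc = tl.make_tensor_descriptor(
            ptr,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[M_BLOCK, N_BLOCK],
            padding_option="zero",  # new keyword; "zero" is the default shown in the hunk
        )
        tile = desc.load([0, 0])
        desc.store([0, 0], tile)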
@@ -2784,6 +2763,79 @@ def gather(src, index, axis, _semantic=None):
     return _semantic.gather(src, index, axis)
 
 
+@builtin
+def map_elementwise(
+    scalar_fn: Callable[..., Tuple[tensor, ...]],
+    *args: tensor,
+    pack=1,
+    _semantic=None,
+    _generator=None,
+):
+    '''
+    Map a scalar function over a tensor.
+
+    The input tensors :code:`args` are implicitly broadcasted to the same shape.
+
+    This may be useful in allowing control flow over single elements in a tensor,
+    for example a multi-branch function where one branch is more expensive. With
+    :code:`tl.where` you are forced to calculate both sides of the branch, but
+    with an if we only execute one side.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        @triton.jit
+        def selu_scalar(x, alpha):
+            if x > 0:
+                return x
+            else:
+                return alpha * (tl.exp(x) - 1)
+
+        @triton.jit
+        def selu(x, alpha):
+            return tl.map_elementwise(selu_scalar, x, alpha)
+
+    :param scalar_fn: the function to map over.
+    :param pack: the number of elements to be processed by one function call.
+    :return: one tensor or a tuple of tensors, depending on the mapped function.
+    '''
+    # Build the block for the nested region first to discover the return types
+    assert pack >= 1
+    in_scalar_tys = [t.type.scalar for t in args]
+    builder = _semantic.builder
+    block = builder.new_block()
+    scalar_args = []
+    for i, ty in enumerate(in_scalar_tys):
+        for j in builtins.range(pack):
+            block.add_argument(ty.to_ir(builder))
+            scalar_args.append(tensor(block.arg(i * pack + j), ty))
+
+    with _insertion_guard(builder):
+        builder.set_insertion_point_to_start(block)
+        scalar_results = _generator.call_JitFunction(scalar_fn, scalar_args, kwargs={})
+
+        is_single = isinstance(scalar_results, tensor)
+        if is_single:
+            scalar_results = scalar_results,
+
+        handles = [r.handle for r in scalar_results]
+        builder.create_map_elementwise_ret(handles)
+
+    fn_result_types = [x.type for x in scalar_results]
+    scalar_result_types = fn_result_types
+    if pack > 1:
+        scalar_result_types = fn_result_types[::pack]
+        for offset in builtins.range(1, pack):
+            assert scalar_result_types == fn_result_types[offset::pack], "type mismatch in unpacked results"
+
+    def make_elementwise_region(elementwise_op):
+        region = elementwise_op.get_region(0)
+        region.push_back(block)
+
+    result = _semantic.map_elementwise(args, scalar_result_types, pack, make_elementwise_region)
+    return result[0] if is_single else result
+
+
 # -----------------------
 # Compiler Hint Ops
 # -----------------------
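Beyond the docstring's `selu` example, the `pack` parameter (validated above via the `fn_result_types[::pack]` check) lets one invocation of the scalar function consume and produce `pack` adjacent elements per tensor. A hedged sketch inferred from that packing logic:

    import triton
    import triton.language as tl

    @triton.jit
    def mul2_scalar(x0, x1, y0, y1):
        # with pack=2, each call receives two adjacent elements per input tensor
        # and returns two packed results that form one logical output
        return x0 * y0, x1 * y1

    @triton.jit
    def mul(x, y):
        return tl.map_elementwise(mul2_scalar, x, y, pack=2)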
@@ -2941,7 +2993,7 @@ def device_print(prefix, *args, hex=False, _semantic=None):
 
 
 @builtin
-def device_assert(cond, msg="", _semantic=None):
+def device_assert(cond, msg="", mask=None, _semantic=None):
     '''
     Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG`
     is set to a value besides :code:`0` in order for this to have any effect.
@@ -2960,7 +3012,10 @@ def device_assert(cond, msg="", mask=None, _semantic=None):
     :param msg: the message to print if the assertion fails. This is required to be a string literal.
     '''
     msg = _unwrap_if_constexpr(msg)
-    return _semantic.device_assert(_semantic.to_tensor(cond), msg)
+    mask = _unwrap_if_constexpr(mask)
+    if mask is not None:
+        mask = _semantic.to_tensor(mask)
+    return _semantic.device_assert(_semantic.to_tensor(cond), msg, mask)
 
 
 @builtin
@@ -3098,7 +3153,7 @@ def inline_asm_elementwise(asm: str, constraints: str, args: Sequence, dtype: Un
 # -----------------------
 
 
-class static_range:
+class static_range(base_value):
     """
     Iterator that counts upward forever.
 
@@ -3154,7 +3209,7 @@ class async_task:
         self.builder.unset_async_task_ids()
 
 
-class range:
+class range(base_value):
     """
     Iterator that counts upward forever.
 
@@ -3189,6 +3244,9 @@ class range:
         The compiler will attempt to partition memory, MMA, and vector
         operations in the loop into separate async partitions. This will
         increase the total number of warps required by the kernel.
+    :param disable_licm: Tells the compiler it shouldn't hoist loop-invariant
+        code outside the loop. This is often useful to avoid creating long live ranges
+        within a loop.
 
     Note that warp specialization is only supported on Blackwell GPUs and
     only works on simple matmul loops. Support for arbitrary loops will be
@@ -3196,7 +3254,7 @@ class range:
     """
 
     def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
-                 disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False):
+                 disallow_acc_multi_buffer=False, flatten=False, warp_specialize=False, disable_licm=False):
         if step is None:
             self.step = constexpr(1)
         else:
@@ -3212,6 +3270,7 @@ class range:
         self.disallow_acc_multi_buffer = disallow_acc_multi_buffer
         self.flatten = flatten
         self.warp_specialize = warp_specialize
+        self.disable_licm = disable_licm
 
     def __iter__(self):
         raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
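A sketch combining the two knobs introduced above: the optional `mask` restricts which lanes `device_assert` checks (it still only fires when `TRITON_DEBUG` is enabled, per the docstring), and `disable_licm` is recorded on the loop object as a compiler attribute:

    import triton
    import triton.language as tl

    @triton.jit
    def kernel(ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(ptr + offs, mask=mask)
        tl.device_assert(x >= 0, "negative input", mask=mask)  # only checks in-bounds lanes
        acc = tl.zeros([BLOCK], dtype=tl.float32)
        for _ in tl.range(0, 8, disable_licm=True):  # ask the compiler not to hoist loop-invariant code
            acc += x
        tl.store(ptr + offs, acc, mask=mask)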
@@ -3220,13 +3279,36 @@ class range:
         raise RuntimeError("tl.range can only be used in @triton.jit'd functions")
 
 
+class condition(base_value):
+    """
+    While loop condition wrapper.
+
+    .. highlight:: python
+    .. code-block:: python
+
+        @triton.jit
+        def kernel(...):
+            while tl.condition(c, disable_licm):
+                ...
+    :note: This is a special wrapper used to annotate while loops in the context of
+        :code:`triton.jit` functions. It allows the user to pass extra attributes to the compiler.
+    :param disable_licm: Tells the compiler it shouldn't hoist loop-invariant
+        code outside the loop. This is often useful to avoid creating long live ranges
+        within a loop.
+    """
+
+    def __init__(self, arg1, disable_licm=False):
+        self.condition = arg1
+        self.disable_licm = disable_licm
+
+
 # -----------------------
 # Extern functions
 # -----------------------
 
 
-def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple,
-             is_pure: bool, _semantic):
+def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_type: dtype, is_pure: bool,
+             _semantic):
     '''
     Dispatch a function to a library
     :param func: the function to dispatch
@@ -3234,7 +3316,7 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
     :param lib_path: the path of the library
     :param args: the arguments of the function
     :param arg_type_symbol_dict: the type of the arguments
-    :param ret_shape: the shape of the return value
+    :param ret_type: the type of the return value
     :return: the return value of the function
     '''
     if len(arg_type_symbol_dict) == 0:
@@ -3261,9 +3343,6 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic
                          f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
     else:
         symbol = arg_type_symbol_dict[arg_types][0]
-        ret_type = arg_type_symbol_dict[arg_types][1]
-        if ret_shape:
-            ret_type = block_type(ret_type, ret_shape)
     builder = _semantic.builder
     return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(builder), is_pure), ret_type)
 
@@ -3282,15 +3361,16 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
     '''
     dispatch_args = args.copy()
     all_scalar = True
-    ret_shape = None
     arg_types = []
     for i in builtins.range(len(dispatch_args)):
         dispatch_args[i] = _semantic.to_tensor(dispatch_args[i])
         arg_types.append(dispatch_args[i].dtype)
         if dispatch_args[i].type.is_block():
             all_scalar = False
+
+    arg_types = tuple(arg_types)
+    ret_type = arg_type_symbol_dict[arg_types][1]
     if len(arg_types) > 0:
-        arg_types = tuple(arg_types)
         arithmetic_check = True
         # If there's a type tuple that is not supported by the library, we will do arithmetic check
         if arg_types in arg_type_symbol_dict:
@@ -3305,9 +3385,9 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol
             dispatch_args[i], _ = _semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg,
                                                                          arithmetic_check=arithmetic_check)
         if not all_scalar:
-            ret_shape = broadcast_arg.shape
+            ret_type = broadcast_arg.type.with_element_ty(ret_type)
     func = _semantic.builder.create_extern_elementwise
-    return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _semantic)
+    return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_type, is_pure, _semantic)
 
 
 def binary_op_type_legalization(lhs, rhs, semantic):
triton/language/extra/cuda/gdc.py
CHANGED

@@ -10,22 +10,22 @@ from triton.language import core
 
 
 @core.extern
-def gdc_wait(_builder=None):
+def gdc_wait(_semantic=None):
     """
     GDC wait is a blocking instruction that waits for all instructions in a prior kernel to complete before continuing.
     This ensures all memory operations happening before the wait is visible to instructions after it,
     e.g. if the prior kernel writes to address "x" the new values will be visible in this kernel after the wait.
 
-    This instruction is also safe to execute when
+    This instruction is also safe to execute when programmatic dependent launch is disabled.
 
     See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol for more details.
     """
     core.inline_asm_elementwise("griddepcontrol.wait; // dummy $0", "=r", [], dtype=core.int32, is_pure=False, pack=1,
-                                _builder=_builder)
+                                _semantic=_semantic)
 
 
 @core.extern
-def gdc_launch_dependents(_builder=None):
+def gdc_launch_dependents(_semantic=None):
     """
     This operation when launched with programmatic dependent launch signals that
     the next program may launch once all programs in the current kernel
@@ -34,9 +34,9 @@ def gdc_launch_dependents(_builder=None):
     Repeated calls to this function have no effect past the first call, and the first call should be
     treated by the programmer as a hint to the runtime system to launch the next kernel.
 
-    This instruction is also safe to execute when
+    This instruction is also safe to execute when programmatic dependent launch is disabled.
 
     See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-griddepcontrol for more details.
     """
     core.inline_asm_elementwise("griddepcontrol.launch_dependents; // dummy $0", "=r", [], dtype=core.int32,
-                                is_pure=False, pack=1, _builder=_builder)
+                                is_pure=False, pack=1, _semantic=_semantic)
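The two helpers now take `_semantic` instead of the old `_builder` argument, matching the rest of the extern API; kernel-side call sites are unchanged. A hedged sketch of typical use with programmatic dependent launch (the host-side launch flag is assumed and not part of this diff):

    import triton
    import triton.language as tl
    from triton.language.extra.cuda.gdc import gdc_wait, gdc_launch_dependents

    @triton.jit
    def consumer(ptr, BLOCK: tl.constexpr):
        gdc_wait()  # wait until the prior kernel's writes are visible
        offs = tl.arange(0, BLOCK)
        x = tl.load(ptr + offs)
        tl.store(ptr + offs, x * 2.0)
        gdc_launch_dependents()  # hint that the dependent kernel may launch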
triton/language/extra/hip/libdevice.py
CHANGED

@@ -73,6 +73,13 @@ def fast_expf(arg0, _semantic=None):
     }, is_pure=True, _semantic=_semantic)
 
 
+@core.extern
+def fast_tanhf(arg0, _semantic=None):
+    return core.extern_elementwise("", "", [arg0], {
+        (core.dtype("fp32"), ): ("__triton_hip_fast_tanhf", core.dtype("fp32")),
+    }, is_pure=True, _semantic=_semantic)
+
+
 @core.extern
 def fast_dividef(arg0, arg1, _semantic=None):
     return core.extern_elementwise("", "", [arg0, arg1], {
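A short usage sketch for the new HIP `fast_tanhf`, which dispatches fp32 inputs to the `__triton_hip_fast_tanhf` symbol per the table above:

    import triton
    import triton.language as tl
    from triton.language.extra.hip import libdevice

    @triton.jit
    def tanh_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)  # fp32 input, per the dtype table
        tl.store(y_ptr + offs, libdevice.fast_tanhf(x))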
triton/language/extra/hip/utils.py
ADDED

@@ -0,0 +1,35 @@
+from triton.language import core
+
+
+@core.extern
+def memrealtime(_semantic=None):
+    """
+    Returns a 64-bit real time-counter value
+    """
+    target_arch = _semantic.builder.options.arch
+    if 'gfx11' in target_arch or 'gfx12' in target_arch:
+        return core.inline_asm_elementwise(
+            """
+            s_sendmsg_rtn_b64 $0, sendmsg(MSG_RTN_GET_REALTIME)
+            s_waitcnt lgkmcnt(0)
+            """,
+            "=r",
+            [],
+            dtype=core.int64,
+            is_pure=False,
+            pack=1,
+            _semantic=_semantic,
+        )
+    else:
+        return core.inline_asm_elementwise(
+            """
+            s_memrealtime $0
+            s_waitcnt vmcnt(0)
+            """,
+            "=r",
+            [],
+            dtype=core.int64,
+            is_pure=False,
+            pack=1,
+            _semantic=_semantic,
+        )
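`memrealtime` selects its instruction sequence by target architecture (`s_sendmsg_rtn_b64` on gfx11/gfx12, `s_memrealtime` elsewhere) and returns an int64 counter. A hedged timing sketch:

    import triton
    import triton.language as tl
    from triton.language.extra.hip.utils import memrealtime

    @triton.jit
    def timed_kernel(ptr, out_ptr, BLOCK: tl.constexpr):
        start = memrealtime()  # 64-bit real-time counter, not a cycle counter
        offs = tl.arange(0, BLOCK)
        x = tl.load(ptr + offs)
        tl.store(ptr + offs, x + 1.0)
        tl.store(out_ptr, memrealtime() - start)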