triton-windows 3.2.0.post12__cp39-cp39-win_amd64.whl → 3.3.0a0.post12__cp39-cp39-win_amd64.whl
- triton/_C/libtriton.pyd +0 -0
- triton/__init__.py +3 -3
- triton/_internal_testing.py +59 -4
- triton/_utils.py +35 -0
- triton/backends/amd/compiler.py +121 -74
- triton/backends/amd/driver.py +77 -43
- triton/backends/amd/include/hip/amd_detail/amd_device_functions.h +28 -49
- triton/backends/amd/include/hip/amd_detail/amd_hip_atomic.h +35 -9
- triton/backends/amd/include/hip/amd_detail/amd_hip_bf16.h +761 -284
- triton/backends/amd/include/hip/amd_detail/amd_hip_cooperative_groups.h +9 -3
- triton/backends/amd/include/hip/amd_detail/amd_hip_fp8.h +1391 -0
- triton/backends/amd/include/hip/amd_detail/amd_hip_gl_interop.h +3 -3
- triton/backends/amd/include/hip/amd_detail/amd_warp_functions.h +44 -0
- triton/backends/amd/include/hip/amd_detail/amd_warp_sync_functions.h +288 -0
- triton/backends/amd/include/hip/amd_detail/hip_api_trace.hpp +110 -14
- triton/backends/amd/include/hip/amd_detail/hip_prof_str.h +504 -103
- triton/backends/amd/include/hip/amd_detail/hip_runtime_prof.h +2 -1
- triton/backends/amd/include/hip/amd_detail/host_defines.h +4 -0
- triton/backends/amd/include/hip/hip_ext.h +4 -2
- triton/backends/amd/include/hip/hip_fp8.h +33 -0
- triton/backends/amd/include/hip/hip_runtime_api.h +375 -33
- triton/backends/amd/include/hip/hip_version.h +3 -3
- triton/backends/amd/include/hip/hiprtc.h +25 -25
- triton/backends/amd/include/hsa/amd_hsa_elf.h +40 -14
- triton/backends/amd/include/hsa/hsa.h +11 -2
- triton/backends/amd/include/hsa/hsa_api_trace.h +30 -17
- triton/backends/amd/include/hsa/hsa_api_trace_version.h +68 -0
- triton/backends/amd/include/hsa/hsa_ext_amd.h +83 -27
- triton/backends/amd/include/hsa/hsa_ven_amd_aqlprofile.h +46 -46
- triton/backends/amd/include/hsa/hsa_ven_amd_pc_sampling.h +416 -0
- triton/backends/amd/include/roctracer/hip_ostream_ops.h +84 -4
- triton/backends/amd/include/roctracer/hsa_ostream_ops.h +260 -0
- triton/backends/amd/include/roctracer/hsa_prof_str.h +51 -19
- triton/backends/amd/lib/asanrtl.bc +0 -0
- triton/backends/compiler.py +25 -225
- triton/backends/driver.py +7 -2
- triton/backends/nvidia/bin/ptxas.exe +0 -0
- triton/backends/nvidia/compiler.py +135 -90
- triton/backends/nvidia/driver.c +0 -1
- triton/backends/nvidia/driver.py +135 -49
- triton/backends/nvidia/include/cuda.h +2162 -241
- triton/backends/nvidia/lib/x64/cuda.lib +0 -0
- triton/compiler/__init__.py +2 -2
- triton/compiler/code_generator.py +334 -231
- triton/compiler/compiler.py +77 -66
- triton/language/__init__.py +22 -5
- triton/language/core.py +448 -74
- triton/language/extra/cuda/_experimental_tma.py +3 -5
- triton/language/math.py +1 -1
- triton/language/random.py +2 -1
- triton/language/semantic.py +206 -52
- triton/language/standard.py +35 -18
- triton/runtime/_allocation.py +32 -0
- triton/runtime/autotuner.py +27 -32
- triton/runtime/build.py +1 -48
- triton/runtime/cache.py +6 -6
- triton/runtime/errors.py +10 -0
- triton/runtime/interpreter.py +179 -45
- triton/runtime/jit.py +149 -190
- triton/testing.py +39 -11
- triton/tools/compile.py +27 -20
- triton/tools/{compile.c → extra/cuda/compile.c} +1 -0
- triton/tools/mxfp.py +301 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/METADATA +5 -2
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/RECORD +68 -59
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/top_level.txt +2 -0
- triton/tools/{compile.h → extra/cuda/compile.h} +0 -0
- {triton_windows-3.2.0.post12.dist-info → triton_windows-3.3.0a0.post12.dist-info}/WHEEL +0 -0
triton/language/core.py
CHANGED
@@ -5,7 +5,7 @@ from contextlib import contextmanager
 from enum import Enum
 from functools import partial, wraps
 import typing
-from typing import Union, Callable, List, Sequence, TypeVar, Optional
+from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
 import builtins
 from ..runtime.jit import jit
 import inspect
@@ -29,7 +29,6 @@ def builtin(fn: T) -> T:
     @wraps(fn)
     def wrapper(*args, **kwargs):
         if "_builder" not in kwargs or kwargs["_builder"] is None:
-            print(kwargs)
             raise ValueError("Did you forget to add @triton.jit ? "
                              "(`_builder` argument must be provided outside of JIT functions.)")
         return fn(*args, **kwargs)
@@ -141,6 +140,7 @@ class constexpr:
             self.value = value.value
         else:
             self.value = value
+        self.type = constexpr
 
     def __repr__(self) -> str:
         return f"constexpr[{self.value}]"
@@ -280,12 +280,40 @@ def check_bit_width(value, shift_value):
         )
 
 
+class base_value:
+    """Base class of values that exist in the triton IR (i.e. not constexprs).
+    """
+    type: base_type
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        """Flatten frontend value into a sequence of mlir handles, which are appended
+        to the output list
+        """
+        raise NotImplementedError
+
+
+class base_type:
+
+    def __eq__(self, other):
+        raise NotImplementedError("Types must implement __eq__")
+
+    def __ne__(self, other):
+        return not (self == other)
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
+        """Build a frontend value with the current dtype, wrapping a list of existing handles.
+        cursor is the index of the first handle relevant to this value, and the function
+        should return the updated cursor position after any handles consumed by the created value.
+        """
+        raise NotImplementedError
+
+
 # -----------------------
 # dtype
 # -----------------------
 
 
-class dtype:
+class dtype(base_type):
     SINT_TYPES = ['int8', 'int16', 'int32', 'int64']
     UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64']
     FP_TYPES = ['fp8e4b15', 'fp8e4nv', 'fp8e4b8', 'fp8e5', 'fp8e5b16', 'fp16', 'bf16', 'fp32', 'fp64']
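Note: the round-trip contract these two new base classes define is simple: `_flatten_ir` appends a value's IR handles to a flat list, and `_unflatten_ir` consumes handles starting at a cursor and returns the advanced cursor. A minimal standalone sketch of that contract (the `Handle`, `PairValue`, and `PairType` names are illustrative stand-ins, not part of Triton):

from typing import List, Tuple

Handle = int  # stand-in for an MLIR ir.value handle

class PairValue:
    """Toy frontend value made of two IR handles, mirroring base_value."""

    def __init__(self, lo: Handle, hi: Handle):
        self.lo, self.hi = lo, hi

    def _flatten_ir(self, handles: List[Handle]) -> None:
        handles.extend([self.lo, self.hi])  # append this value's handles

class PairType:
    """Toy type mirroring base_type: rebuilds a value from the handle list."""

    def _unflatten_ir(self, handles: List[Handle], cursor: int) -> Tuple[PairValue, int]:
        value = PairValue(handles[cursor], handles[cursor + 1])
        return value, cursor + 2  # cursor moves past the consumed handles

handles: List[Handle] = []
PairValue(1, 2)._flatten_ir(handles)
rebuilt, cursor = PairType()._unflatten_ir(handles, 0)
assert (rebuilt.lo, rebuilt.hi, cursor) == (1, 2, 2)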
@@ -474,14 +502,15 @@ class dtype:
     def is_const():
         return False
 
+    @staticmethod
+    def is_tuple():
+        return False
+
     def __eq__(self, other: dtype):
         if not isinstance(other, dtype):
             return False
         return self.name == other.name
 
-    def __ne__(self, other: dtype):
-        return not self.__eq__(other)
-
     def __hash__(self):
         return hash((self.name, ))
 
@@ -549,6 +578,9 @@ class dtype:
         """Output of repr needs to be an evaluatable expression"""
         return f'triton.language.{self.codegen_name()}'
 
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[base_value, int]:
+        return tensor(handles[cursor], self), cursor + 1
+
 
 # Some functions have a param named `dtype`, which shadows the `dtype` class.
 # We can't change the param name because it is part of function's public API.
@@ -587,9 +619,6 @@ class pointer_type(dtype):
             return False
         return self.element_ty == other.element_ty and self.address_space == other.address_space and self.const == other.const
 
-    def __ne__(self, other: pointer_type) -> bool:
-        return not self.__eq__(other)
-
     @property
     def scalar(self):
         return self
@@ -609,9 +638,10 @@ class block_type(dtype):
 
         # Note that block_type's shape is a list of int
         # while tensor's shape is a list of constexpr.
+        assert (isinstance(shape, (list, tuple)))
 
         # shape can be empty ([]) when an input is a 0D tensor.
-        self.shape = _unwrap_shape(shape)
+        self.shape = tuple(_unwrap_shape(shape))
         if not self.shape:
             raise TypeError('0d block_type is forbidden')
 
@@ -633,32 +663,53 @@ class block_type(dtype):
     def get_block_shapes(self) -> List[int]:
         return self.shape
 
-    def __eq__(self, other: block_type) -> bool:
+    def __eq__(self, other) -> bool:
         if not isinstance(other, block_type):
             return False
         return self.element_ty == other.element_ty and self.shape == other.shape
 
-    def __ne__(self, other: block_type) -> bool:
-        return not self.__eq__(other)
-
     @property
     def scalar(self):
         return self.element_ty
 
 
-class function_type(dtype):
+class tuple_type(base_type):
 
-    def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None:
-        self.ret_types = ret_types
-        self.param_types = param_types
+    def __init__(self, types, fields=None):
+        self.types = types
+        self.fields = fields or [''] * len(types)
+        self.name = '[' + ','.join([f"{k}:{v}" for k, v in zip(self.fields, self.types)]) + ']'
 
     def __str__(self):
-        return f'fn ({self.param_types}) -> {self.ret_types}'
+        return self.name
+
+    def __iter__(self):
+        return iter(self.types)
 
     def to_ir(self, builder: ir.builder):
-        ir_param_types = [ty.to_ir(builder) for ty in self.param_types]
-        ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types]
-        return builder.get_function_ty(ir_param_types, ret_types)
+        return [ty.to_ir(builder) for ty in self.types]
+
+    def __getitem__(self, index: int) -> dtype:
+        return self.types[index]
+
+    def is_tuple(self):
+        return True
+
+    def __eq__(self, other):
+        return type(self) is type(other) and self.types == other.types and self.fields == other.fields
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tuple, int]:
+        values = []
+        for ty in self.types:
+            value, cursor = ty._unflatten_ir(handles, cursor)
+            values.append(value)
+        return tuple(values, self), cursor
+
+
+class slice_type(dtype):
+
+    def __init__(self):
+        self.name = 'slice_type'
 
 
 # scalar types
@@ -708,20 +759,12 @@ def get_int_dtype(bitwidth: int, signed: bool) -> dtype:
     raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}')
 
 
-class _value:
-    """Base class of values that exist in the triton IR (i.e. not constexprs).
-    """
-
-    def __init__(self, handle):
-        self.handle = handle
-
-
 # -----------------------
 # tensor
 # -----------------------
 
 
-class tensor(_value):
+class tensor(base_value):
     """Represents an N-dimensional array of values or pointers.
 
     :code:`tensor` is the fundamental data structure in Triton programs. Most
@@ -743,8 +786,9 @@ class tensor(_value):
 
     def __init__(self, handle, type: dtype):
         """Not called by user code."""
+        super().__init__()
         # IR handle
-        super().__init__(handle)
+        self.handle = handle
         # Block shape
         self.shape = type.shape if type.is_block() else ()
         self.numel = 1
@@ -754,7 +798,10 @@ class tensor(_value):
         self.type = type  # Tensor type (can be block_type)
         # Following the practice in pytorch, dtype is scalar type
         self.dtype = type.scalar
-        self.shape = [constexpr(s) for s in self.shape]
+        self.shape = tuple([constexpr(s) for s in self.shape])
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
 
     def __str__(self) -> str:
         # ex. "float32[16, 32]"
@@ -968,13 +1015,16 @@ class tensor(_value):
 
     @builtin
     def __getitem__(self, slices, _builder=None):
-        if isinstance(slices, (slice, constexpr)) or slices is None:
+        import builtins
+        if isinstance(slices, (builtins.slice, slice, constexpr)) or slices is None:
             slices = [slices]
+        if isinstance(slices, tuple):
+            slices = slices.values
         ret = self
         for dim, sl in enumerate(slices):
             if sl is None or isinstance(sl, constexpr) and sl.value is None:
                 ret = semantic.expand_dims(ret, dim, _builder)
-            elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None:
+            elif isinstance(sl, (builtins.slice, slice)) and sl.start is None and sl.stop is None and sl.step is None:
                 pass
             else:
                 raise ValueError(f"unsupported tensor index: {sl}")
@@ -990,13 +1040,7 @@ class tensor(_value):
         """
         Alias for :py:func:`tensor.cast`.
        """
-
-        # just copy-paste the implementation of cast here. It's not too bad.
-        dtype = _unwrap_if_constexpr(dtype)
-        bitcast = _unwrap_if_constexpr(bitcast)
-        if bitcast:
-            return semantic.bitcast(self, dtype, _builder)
-        return semantic.cast(self, dtype, _builder, fp_downcast_rounding)
+        return cast(self, dtype, fp_downcast_rounding, bitcast, _builder=_builder)
 
 # Type stubs for functions added by the _tensor_member_fn decorator.
 # (Unfortunately these can't be created automatically.)
@@ -1084,6 +1128,9 @@ class tensor(_value):
     def associative_scan(self, axis, combine_fn, reverse=False) -> tensor:
         ...
 
+    def gather(self, indices, axis) -> tensor:
+        ...
+
     def histogram(self, num_bins) -> tensor:
         ...
 
@@ -1111,7 +1158,7 @@ class tensor(_value):
     def argmin(self, axis, tie_break_left=True, keep_dims=False) -> tensor:
         ...
 
-    def sum(self, axis=None, keep_dims=False) -> tensor:
+    def sum(self, axis=None, keep_dims=False, dtype=None) -> tensor:
         ...
 
     def xor_sum(self, axis=None, keep_dims=False) -> tensor:
@@ -1130,6 +1177,223 @@ class tensor(_value):
         ...
 
 
+class tuple(base_value):
+
+    def __init__(self, args: list, type: tuple_type = None):
+        self.values = [i for i in args]
+
+        def get_type(x):
+            if isinstance(x, dtype):
+                return dtype
+            if isinstance(x, int):
+                return constexpr
+            return x.type
+
+        self.type = type or tuple_type([get_type(x) for x in self.values])
+
+    def __getitem__(self, idx: constexpr):
+        if isinstance(idx, int):
+            idx = constexpr(idx)
+        if isinstance(idx, constexpr):
+            return self.values[idx]
+        else:
+            import builtins
+            assert isinstance(idx, (slice, builtins.slice))
+            return tuple(self.values[idx.start:idx.stop:idx.step])
+
+    def __getattr__(self, name):
+        return self.values[self.type.fields.index(name)]
+
+    # TODO: remove
+    def __setitem__(self, idx: constexpr, value):
+        if isinstance(idx, int):
+            idx = constexpr(idx)
+        assert isinstance(idx, constexpr)
+        self.values[idx] = value
+
+    def __add__(self, other):
+        if isinstance(other, list):
+            other = tuple(other)
+        return tuple(self.values + other.values)
+        # return tuple(a + b for a, b in zip(self.values, other.values))
+
+    def __mul__(self, other):
+        assert isinstance(other, constexpr)
+        return tuple(self.values * other.value)
+
+    def __eq__(self, other):
+        import builtins
+        if isinstance(other, (list, builtins.tuple)):
+            other = tuple(other)
+        return constexpr(self.values == other.values)
+
+    def __hash__(self):
+        import builtins
+        return hash(builtins.tuple(self.values))
+
+    def __str__(self):
+        return str([str(x) for x in self.values])
+
+    def __iter__(self):
+        return iter(self.values)
+
+    def __len__(self):
+        return len(self.values)
+
+    def _flatten_ir(self, handles: List[ir.value]):
+        for v in self.values:
+            v._flatten_ir(handles)
+
+
+class slice:
+
+    def __init__(self, start, stop, step):
+        self.start = start
+        self.stop = stop
+        self.step = step
+        self.type = slice_type()
+
+
+class tensor_descriptor_base_type(base_type):
+
+    def __init__(self, block_type: block_type):
+        self.block_type = block_type
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[_experimental_tensor_descriptor_base, int]:
+        value = _experimental_tensor_descriptor_base(handles[cursor], self.block_type)
+        return value, cursor + 1
+
+    def to_ir(self, builder: ir.builder):
+        return builder.create_tensor_descriptor_type(self.block_type.to_ir(builder))
+
+    def __str__(self) -> str:
+        # ex. "tensor_descriptor<float32[16, 32]>"
+        return f"tensor_descriptor<{self.block_type}>"
+
+    def __eq__(self, other) -> bool:
+        if type(other) is not type(self):
+            return False
+        return self.block_type == other.block_type
+
+    def __neq__(self, other) -> bool:
+        return not (self == other)
+
+
+class _experimental_tensor_descriptor_base(base_value):
+    """"
+    A tensor descriptor with unknown shape and strides
+    """
+
+    def __init__(self, handle, block_type: block_type):
+        """Not called by user code."""
+        super().__init__()
+
+        self.handle = handle  # IR handle
+        self.type = tensor_descriptor_base_type(block_type)  # Tensor type (block_type)
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+
+    @property
+    def block_type(self):
+        return self.type.block_type
+
+    @property
+    def block_shape(self):
+        return self.type.block_type.shape
+
+    @property
+    def dtype(self):
+        return self.type.block_type.element_ty
+
+    def __str__(self) -> str:
+        return str(self.type)
+
+    @builtin
+    def load(self, offsets: Sequence[constexpr | tensor], _builder=None) -> tensor:
+        """Load a block from the descriptor starting at the given element offsets.
+
+        Values outside of the tensor bounds will be filled with zeros.
+
+        :note: Offset must be a multiple of 16-bytes
+        """
+        return semantic.descriptor_load(self, offsets, "", "", _builder)
+
+    @builtin
+    def store(self, offsets: Sequence[constexpr | tensor], value: tensor, _builder=None) -> tensor:
+        """Store a block from the descriptor starting at the given element offsets.
+
+        Values outside of the tensor bounds will be ignored.
+
+        :note: Offset must be a multiple of 16-bytes
+        """
+        return semantic.descriptor_store(self, value, offsets, _builder)
+
+    @builtin
+    def gather(self, *args, _builder=None) -> tensor:
+        """Gather multiple descriptors worth of data"""
+        assert len(args) == 2, f"descriptor gather only supports 2D indexing, but got {len(args)}"
+        x_offsets = args[0]
+        y_offset = args[1]
+        return semantic.descriptor_gather(self, x_offsets, y_offset, "", "", _builder)
+
+    @builtin
+    def scatter(self, value, *args, _builder=None) -> tensor:
+        """Scatter multiple descriptors worth of data"""
+        assert len(args) == 2, f"descriptor scatter only supports 2D indexing, but got {len(args)}"
+        x_offsets = args[0]
+        y_offset = args[1]
+        return semantic.descriptor_scatter(self, value, x_offsets, y_offset, _builder)
+
+
+class tensor_descriptor_type(tensor_descriptor_base_type):
+
+    def __init__(self, block_type: block_type, shape_type: tuple_type, strides_type: tuple_type):
+        self.block_type = block_type
+        self.shape_type = shape_type
+        self.strides_type = strides_type
+
+    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[_experimental_tensor_descriptor_base, int]:
+        handle = handles[cursor]
+        cursor += 1
+        shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
+        strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
+        shape = shape.values
+        strides = strides.values
+        value = _experimental_tensor_descriptor(handle, shape, strides, self.block_type)
+        return value, cursor
+
+    def to_ir(self, builder: ir.builder):
+        return [super().to_ir(builder), *self.shape_type.to_ir(builder), *self.strides_type.to_ir(builder)]
+
+    def __eq__(self, other):
+        return super().__eq__(other) and (self.shape_type == other.shape_type) and (self.strides_type
+                                                                                    == other.strides_type)
+
+
+class _experimental_tensor_descriptor(_experimental_tensor_descriptor_base):
+    """A descriptor representing a tensor in global memory.
+    """
+
+    def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type):
+        """Not called by user code."""
+        # IR handle
+        super().__init__(handle, block_type)
+        self.type = tensor_descriptor_type(
+            block_type,
+            shape_type=tuple_type([s.type for s in shape]),
+            strides_type=tuple_type([s.type for s in strides]),
+        )
+        # Global shape
+        self.shape = shape
+        self.strides = strides
+
+    def _flatten_ir(self, handles: List[ir.value]) -> None:
+        handles.append(self.handle)
+        handles.extend(s.handle for s in self.shape)
+        handles.extend(s.handle for s in self.strides)
+
+
 def get_bool_env_var(var_name):
     v = os.getenv(var_name, "0")
     return v == "1" or v == "true" or v == "on"
@@ -1290,6 +1554,7 @@ def trans(input: tensor, *dims, _builder=None):
     :py:func:`permute` is equivalent to this function, except it doesn't
     have the special case when no permutation is specified.
     """
+    dims = _unwrap_iterable(dims)
     if not dims:
         dims = (1, 0)
     return semantic.permute(input, dims, _builder)
@@ -1467,7 +1732,7 @@ def expand_dims(input, axis, _builder=None):
     """
     input = semantic.to_tensor(input, _builder)
     axis = _constexpr_to_value(axis)
-    axes = list(axis) if isinstance(axis, Sequence) else [axis]
+    axes = list(axis) if isinstance(axis, (Sequence, tuple)) else [axis]
     new_ndim = len(input.shape) + len(axes)
     axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes]
 
@@ -1499,8 +1764,9 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas
     :type bitcast: bool, optional
     """
     input = semantic.to_tensor(input, _builder)
-
-
+    dtype = _constexpr_to_value(dtype)
+    fp_downcast_rounding = _constexpr_to_value(fp_downcast_rounding)
+    bitcast = _constexpr_to_value(bitcast)
     if bitcast:
         return semantic.bitcast(input, dtype, _builder)
     return semantic.cast(input, dtype, _builder, fp_downcast_rounding)
@@ -1522,16 +1788,16 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
     where the first dimension of each block represents the batch dimension.
 
     :param input: The first tensor to be multiplied.
-    :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8e5`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
+    :type input: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
     :param other: The second tensor to be multiplied.
-    :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8e5`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
+    :type other: 2D or 3D tensor of scalar-type in {:code:`int8`, :code:`float8_e5m2`, :code:`float16`, :code:`bfloat16`, :code:`float32`}
     :param acc: The accumulator tensor. If not None, the result is added to this tensor.
     :type acc: 2D or 3D tensor of scalar-type in {:code:`float16`, :code:`float32`, :code:`int32`}
     :param input_precision: How to exercise the Tensor Cores for f32 x f32. If
       the device does not have Tensor Cores or the inputs are not of dtype f32,
       this option is ignored. For devices that do have tensor cores, the
       default precision is tf32.
-    :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`.
+    :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`.
     :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32".
       Only one of :code:`input_precision` and :code:`allow_tf32` can be
       specified (i.e. at least one must be :code:`None`).
@@ -1549,26 +1815,39 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
 
 
 @builtin
-def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, out_dtype=float32, _builder=None):
+def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, out_dtype=float32,
+               _builder=None):
     """
     Returns the matrix product of two blocks in microscaling format.
+
     lhs and rhs use microscaling formats described here:
     https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+
+    Software emulation enables targeting hardware architectures without native microscaling
+    operation support. Right now for such case, microscaled lhs/rhs are upcasted to
+    :code:`bf16` element type beforehand for dot computation, with one exception:
+    for AMD CDNA3 specifically, if one of the inputs is of :code:`fp16` element type,
+    the other input is also upcasted to :code:`fp16` element type instead.
+    This behavior is experimental and may be subject to change in the future.
+
     :param lhs: The first tensor to be multiplied.
-    :type lhs: 2D tensor
+    :type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
     :param lhs_scale: Scale factor for lhs tensor.
-    :type lhs_scale:
-    :param lhs_format: format of the lhs tensor
+    :type lhs_scale: e8m0 type represented as an uint8 tensor.
+    :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
+    :type lhs_format: str
     :param rhs: The second tensor to be multiplied.
-    :type rhs: 2D tensor
+    :type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
     :param rhs_scale: Scale factor for rhs tensor.
-    :type rhs_scale:
-    :param rhs_format: format of the rhs tensor
+    :type rhs_scale: e8m0 type represented as an uint8 tensor.
+    :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`, :code:`fp16`}.
+    :type rhs_format: str
     :param acc: The accumulator tensor. If not None, the result is added to this tensor.
     """
     out_dtype = _constexpr_to_value(out_dtype)
     assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment"
-    return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, out_dtype, _builder)
+    return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, fast_math, out_dtype,
+                               _builder)
 
 
 # -----------------------
@@ -1636,6 +1915,16 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
                          volatile, _builder)
 
 
+@builtin
+def _experimental_reinterpret_tensor_descriptor(desc_ptr, block_shape, dtype,
+                                                _builder=None) -> _experimental_tensor_descriptor_base:
+    """
+    Reinterpret a generic pointer as a TMA-backed tensor descriptor object.
+    """
+    block_ty = block_type(_constexpr_to_value(dtype), block_shape)
+    return semantic.reinterpret_tensor_descriptor(desc_ptr, block_ty, _builder)
+
+
 @builtin
 def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=None):
     """
@@ -1644,8 +1933,8 @@ def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=
 
     This loads a tensor of data based on the descriptor and offsets.
     """
-    type = block_type(dtype, shape)
-    return semantic.descriptor_load(desc_pointer, offsets, type, _builder)
+    desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, shape, dtype, _builder=_builder)
+    return desc.load(offsets, _builder=_builder)
 
 
 @builtin
@@ -1656,7 +1945,8 @@ def _experimental_descriptor_store(desc_pointer, value, offsets, _builder=None):
 
     This stores a tensor of data based on the descriptor and offsets.
     """
-    return semantic.descriptor_store(desc_pointer, value, offsets, _builder)
+    desc = _experimental_reinterpret_tensor_descriptor(desc_pointer, value.shape, value.dtype, _builder=_builder)
+    return desc.store(offsets, value, _builder=_builder)
 
 
 @_tensor_member_fn
@@ -1737,6 +2027,64 @@ def advance(base, offsets, _builder=None):
     return semantic.advance(base, offsets, _builder)
 
 
+@builtin
+def _experimental_make_tensor_descriptor(
+    base: tensor,
+    shape: List[tensor],
+    strides: List[tensor],
+    block_shape: List[constexpr],
+    _builder=None,
+) -> _experimental_tensor_descriptor:
+    """Make an experimental tensor descriptor object
+
+    :param base: the base pointer of the tensor, must be 16-byte aligned
+    :param shape: A list of non-negative integers representing the tensor shape
+    :param strides: A list of tensor strides. Leading dimensions must be multiples
+        of 16-byte strides and the last dimension must be contiguous.
+    :param block_shape: The shape of block to be loaded/stored from global memory
+
+    Notes
+    *****
+    On NVIDIA GPUs with TMA support, this will result in a TMA descriptor object
+    and loads and stores from the descriptor will be backed by the TMA hardware.
+
+    Currently only 2-5 dimensional tensors are supported.
+
+    Example
+    *******
+    .. code-block:: python
+
+        @triton.jit
+        def inplace_abs(in_out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
+            desc = tl._experimental_make_tensor_descriptor(
+                in_out_ptr,
+                shape=[M, N],
+                strides=[N, 1],
+                block_shape=[M_BLOCK, N_BLOCK],
+            )
+
+            moffset = tl.program_id(0) * M_BLOCK
+            noffset = tl.program_id(1) * N_BLOCK
+
+            value = desc.load([moffset, noffset])
+            desc.store([moffset, noffset], tl.abs(value))
+
+        # TMA descriptors require a global memory allocation
+        def alloc_fn(size: int, alignment: int, stream: Optional[int]):
+            return torch.empty(size, device="cuda", dtype=torch.int8)
+
+        triton.set_allocator(alloc_fn)
+
+        M, N = 256, 256
+        x = torch.randn(M, N, device="cuda")
+        M_BLOCK, N_BLOCK = 32, 32
+        grid = (M / M_BLOCK, N / N_BLOCK)
+        inplace_abs[grid](x, M, N, M_BLOCK, N_BLOCK)
+
+    """
+    return semantic.make_tensor_descriptor(base, shape, strides, block_shape, _builder)
+
+
 # -----------------------
 # Atomic Memory Operations
 # -----------------------
@@ -1994,7 +2342,8 @@ def clamp(x, min, max, propagate_nan: constexpr = PropagateNan.NONE, _builder=No
 # -----------------------
 
 
-def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]:
+def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None,
+                          dtype_arg: str = None) -> Callable[[T], T]:
 
     def _decorator(func: T) -> T:
         docstr = """
@@ -2014,6 +2363,10 @@ def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_a
             docstr += f"""
 :param {tie_break_arg}: if true, in case of a tie (i.e., multiple elements have the same {name} value), return the left-most index for values that aren't NaN
 :type {tie_break_arg}: bool"""
+        if dtype_arg is not None:
+            docstr += f"""
+:param {dtype_arg}: the desired data type of the returned tensor. If specified, the input tensor is casted to :code:`{dtype_arg}` before the operation is performed. This is useful for preventing data overflows. If not specified, integer and bool dtypes are upcasted to :code:`tl.int32` and float dtypes are upcasted to at least :code:`tl.float32`.
+:type {dtype_arg}: tl.dtype"""
 
         func.__doc__ = docstr.format(name=name)
         return func
@@ -2047,14 +2400,12 @@ def reduce(input, axis, combine_fn, keep_dims=False, _builder=None, _generator=N
         return reduce((input, ), axis, combine_fn, keep_dims=keep_dims, _builder=_builder, _generator=_generator)[0]
 
     def make_combine_region(reduce_op):
-        in_scalar_tys = [t.type.scalar for t in input]
-        prototype = function_type(in_scalar_tys, in_scalar_tys * 2)
-
+        param_types = [t.type.scalar for t in input] * 2
         region = reduce_op.get_region(0)
         with _insertion_guard(_builder):
-            param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
-            block = _builder.create_block_with_parent(region, param_types)
-            args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
+            to_ir = lambda T: T.to_ir(_builder)
+            block = _builder.create_block_with_parent(region, list(map(to_ir, param_types)))
+            args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
         results = _generator.call_JitFunction(combine_fn, args, kwargs={})
         if isinstance(results, tensor):
             handles = [results.handle]
@@ -2148,14 +2499,12 @@ def associative_scan(input, axis, combine_fn, reverse=False, _builder=None, _gen
         return associative_scan((input, ), axis, combine_fn, reverse, _builder=_builder, _generator=_generator)[0]
 
     def make_combine_region(scan_op):
-        in_scalar_tys = [t.type.scalar for t in input]
-        prototype = function_type(in_scalar_tys, in_scalar_tys * 2)
-
+        param_types = [t.type.scalar for t in input] * 2
         region = scan_op.get_region(0)
         with _insertion_guard(_builder):
-            param_types = [ty.to_ir(_builder) for ty in prototype.param_types]
-            block = _builder.create_block_with_parent(region, param_types)
-            args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)]
+            to_ir = lambda T: T.to_ir(_builder)
+            block = _builder.create_block_with_parent(region, list(map(to_ir, param_types)))
+            args = [tensor(block.arg(i), ty) for i, ty in enumerate(param_types)]
         results = _generator.call_JitFunction(combine_fn, args, kwargs={})
         if isinstance(results, tensor):
             handles = [results.handle]
@@ -2184,6 +2533,23 @@ def histogram(input, num_bins, _builder=None, _generator=None):
     return semantic.histogram(input, num_bins, _builder)
 
 
+@_tensor_member_fn
+@builtin
+def gather(src, index, axis, _builder=None):
+    """Gather from a tensor along a given dimension.
+
+    :param src: the source tensor
+    :type src: Tensor
+    :param index: the index tensor
+    :type index: Tensor
+    :param axis: the dimension to gather along
+    :type axis: int
+
+    """
+    axis = _constexpr_to_value(axis)
+    return semantic.gather(src, index, axis, _builder)
+
+
 # -----------------------
 # Compiler Hint Ops
 # -----------------------
@@ -2565,9 +2931,15 @@ class range:
     :param loop_unroll_factor: Tells the Triton IR level loop unroller how many
         times to unroll a for loop that this range is used with. Less than 2 for
         this value implies no unrolling.
+    :param disallow_acc_multi_buffer: If true, prevent the accumulator of the dot
+        operation in the loop to be multi-buffered, if applicable.
+    :param flatten: automatically flatten the loop nest starting at this loop to
+        create a single flattened loop. The compiler will try to pipeline the
+        flattened loop which can avoid stage stalling.
     """
 
-    def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None):
+    def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_unroll_factor=None,
+                 disallow_acc_multi_buffer=False, flatten=False):
         if step is None:
             self.step = constexpr(1)
         else:
@@ -2580,6 +2952,8 @@ class range:
             self.end = arg2
         self.num_stages = num_stages
         self.loop_unroll_factor = loop_unroll_factor
+        self.disallow_acc_multi_buffer = disallow_acc_multi_buffer
+        self.flatten = flatten
 
     def __iter__(self):
         raise RuntimeError("tl.range can only be used in @triton.jit'd functions")